1use std::collections::HashMap;
13
14use chrono::{DateTime, Utc};
15use rusqlite::{params, Connection};
16use serde::{Deserialize, Serialize};
17
18use crate::error::{EngramError, Result};
19use crate::storage::queries::create_memory;
20use crate::types::{CreateMemoryInput, MemoryTier, MemoryType};
21
/// Tuning knobs for [`chunk_conversation`] and the indexing functions built on it.
#[derive(Debug, Clone)]
pub struct ChunkingConfig {
    /// Maximum number of messages admitted into one chunk.
    pub max_messages: usize,
    /// How many trailing messages of one chunk are repeated at the start of the next.
    pub overlap_messages: usize,
    /// Content budget per chunk, measured with `str::len` (UTF-8 bytes); a single message
    /// longer than this is truncated.
    pub max_chars: usize,
    /// TTL in seconds attached to every stored transcript-chunk memory.
    pub default_ttl_seconds: i64,
}

impl Default for ChunkingConfig {
    /// 10 messages per chunk, 2 messages of overlap, an 8000-byte content budget and a
    /// one-week TTL.
    fn default() -> Self {
        const ONE_WEEK_SECS: i64 = 7 * 24 * 60 * 60;

        ChunkingConfig {
            max_chars: 8000,
            max_messages: 10,
            overlap_messages: 2,
            default_ttl_seconds: ONE_WEEK_SECS,
        }
    }
}
45
/// One conversation message as supplied by the caller (and as stored in session metadata
/// under `"overlap_messages"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Speaker label; it is rendered verbatim as the `[role]: ` prefix in chunk content
    /// (e.g. `"user"` / `"assistant"` in the tests).
    pub role: String,
    /// Message text. Chunking may replace an over-long copy with a truncated version; the
    /// caller's value is not modified.
    pub content: String,
    /// When the message was produced. If the field is absent while deserializing, it
    /// defaults to the time of deserialization (`Utc::now`).
    #[serde(default = "Utc::now")]
    pub timestamp: DateTime<Utc>,
    /// Optional caller-assigned message identifier, carried through into chunks unchanged.
    pub id: Option<String>,
}
59
/// A contiguous window of messages produced by [`chunk_conversation`].
#[derive(Debug, Clone)]
pub struct ConversationChunk {
    /// Position of this chunk in the vector returned by `chunk_conversation`.
    pub chunk_index: usize,
    /// Index (into the input slice) of the first message in the chunk, inclusive.
    pub start_index: usize,
    /// Index (into the input slice) one past the last message in the chunk, exclusive.
    pub end_index: usize,
    /// Copies of the chunk's messages; over-long content is already truncated here.
    pub messages: Vec<Message>,
    /// Rendered text: one `[role]: content` entry per message, separated by blank lines.
    pub content: String,
    /// Length of `content` as reported by `str::len`, i.e. in UTF-8 bytes.
    pub char_count: usize,
}
76
/// An indexed conversation, mirroring one row of the `sessions` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
    /// Caller-chosen key of the session (primary conflict target of the upsert).
    pub session_id: String,
    /// Optional human-readable title. Re-indexing with `None` keeps the stored title.
    pub title: Option<String>,
    /// Optional identifier of the agent that owns the conversation.
    pub agent_id: Option<String>,
    /// Timestamp of the first message at initial indexing. When loaded from the database,
    /// an unparseable value falls back to the current time.
    pub started_at: DateTime<Utc>,
    /// When the session was last (re-)indexed; `None` if absent or unparseable.
    pub last_indexed_at: Option<DateTime<Utc>>,
    /// Total number of messages indexed so far.
    pub message_count: i64,
    /// Total number of chunks (transcript memories) created so far.
    pub chunk_count: i64,
    /// Workspace the chunk memories were written to (`"default"` when none was given).
    pub workspace: String,
    /// Free-form JSON metadata. The indexing functions store the trailing messages under
    /// `"overlap_messages"` so incremental indexing can continue with overlap.
    pub metadata: HashMap<String, serde_json::Value>,
}
99
100pub fn chunk_conversation(messages: &[Message], config: &ChunkingConfig) -> Vec<ConversationChunk> {
115 if messages.is_empty() {
116 return vec![];
117 }
118
119 let mut chunks = Vec::new();
120 let mut chunk_start = 0;
121
122 while chunk_start < messages.len() {
123 let mut current_messages = Vec::new();
124 let mut current_chars = 0;
125 let mut i = chunk_start;
126
127 while i < messages.len() {
129 let msg = &messages[i];
130 let msg_chars = msg.content.len();
131
132 let (content, chars) = if msg_chars > config.max_chars {
134 let truncated = truncate_with_marker(&msg.content, config.max_chars);
135 (truncated.clone(), truncated.len())
136 } else {
137 (msg.content.clone(), msg_chars)
138 };
139
140 let would_exceed_chars =
142 current_chars + chars > config.max_chars && !current_messages.is_empty();
143 let would_exceed_messages = current_messages.len() >= config.max_messages;
144
145 if would_exceed_chars || would_exceed_messages {
146 break;
147 }
148
149 current_messages.push(Message {
151 role: msg.role.clone(),
152 content,
153 timestamp: msg.timestamp,
154 id: msg.id.clone(),
155 });
156 current_chars += chars;
157 i += 1;
158 }
159
160 if !current_messages.is_empty() {
162 let chunk_content = format_chunk_content(¤t_messages);
163 chunks.push(ConversationChunk {
164 chunk_index: chunks.len(),
165 start_index: chunk_start,
166 end_index: i,
167 messages: current_messages,
168 content: chunk_content.clone(),
169 char_count: chunk_content.len(),
170 });
171 }
172
173 let overlap = config.overlap_messages.min(i - chunk_start);
175 chunk_start = if i >= messages.len() {
176 messages.len() } else if i > chunk_start + overlap {
178 i - overlap
179 } else {
180 i };
182 }
183
184 chunks
185}
186
/// Shorten `content` to at most `max_chars` bytes, keeping both its head and its tail.
///
/// Content that already fits is returned unchanged. Otherwise, the budget left after the
/// marker `"\n[...truncated...]\n"` is split roughly 60% head / 40% tail around that marker.
///
/// Lengths are UTF-8 byte lengths, the same `str::len` measure `chunk_conversation` uses.
/// Cut points are moved to char boundaries, so multi-byte text never exceeds the budget and
/// is never split mid-character.
///
/// If `max_chars` is too small to hold the marker, the content is cut to the longest prefix
/// that fits, with no marker.
fn truncate_with_marker(content: &str, max_chars: usize) -> String {
    const MARKER: &str = "\n[...truncated...]\n";

    if content.len() <= max_chars {
        return content.to_string();
    }

    // Largest char boundary <= `idx` (callers guarantee `idx <= content.len()`).
    let floor_boundary = |mut idx: usize| {
        while !content.is_char_boundary(idx) {
            idx -= 1;
        }
        idx
    };

    // Without this guard `max_chars - MARKER.len()` would underflow.
    if max_chars < MARKER.len() {
        return content[..floor_boundary(max_chars)].to_string();
    }

    let available = max_chars - MARKER.len();
    let head_len = available * 60 / 100;
    let tail_len = available - head_len;

    let head_end = floor_boundary(head_len);
    // Smallest char boundary >= the ideal tail start, so the tail never exceeds `tail_len`.
    let mut tail_start = content.len() - tail_len;
    while !content.is_char_boundary(tail_start) {
        tail_start += 1;
    }

    format!("{}{}{}", &content[..head_end], MARKER, &content[tail_start..])
}
211
212fn format_chunk_content(messages: &[Message]) -> String {
214 messages
215 .iter()
216 .map(|m| format!("[{}]: {}", m.role, m.content))
217 .collect::<Vec<_>>()
218 .join("\n\n")
219}
220
221pub fn index_conversation(
238 conn: &Connection,
239 session_id: &str,
240 messages: &[Message],
241 config: &ChunkingConfig,
242 workspace: Option<&str>,
243 title: Option<&str>,
244 agent_id: Option<&str>,
245) -> Result<Session> {
246 let now = Utc::now();
247 let workspace = workspace.unwrap_or("default");
248
249 let chunks = chunk_conversation(messages, config);
251
252 if chunks.is_empty() {
253 return Err(EngramError::InvalidInput(
254 "No messages to index".to_string(),
255 ));
256 }
257
258 let overlap_messages: Vec<&Message> = messages
260 .iter()
261 .rev()
262 .take(config.overlap_messages)
263 .collect::<Vec<_>>()
264 .into_iter()
265 .rev()
266 .collect();
267
268 let mut session_metadata = HashMap::new();
269 session_metadata.insert(
270 "overlap_messages".to_string(),
271 serde_json::to_value(&overlap_messages).unwrap_or_default(),
272 );
273 let metadata_json = serde_json::to_string(&session_metadata)?;
274
275 let started_at = messages.first().map(|m| m.timestamp).unwrap_or(now);
277
278 conn.execute(
279 r#"
280 INSERT INTO sessions (session_id, title, agent_id, started_at, last_indexed_at,
281 message_count, chunk_count, workspace, metadata)
282 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
283 ON CONFLICT(session_id) DO UPDATE SET
284 title = COALESCE(excluded.title, sessions.title),
285 last_indexed_at = excluded.last_indexed_at,
286 message_count = excluded.message_count,
287 chunk_count = excluded.chunk_count,
288 metadata = excluded.metadata
289 "#,
290 params![
291 session_id,
292 title,
293 agent_id,
294 started_at.to_rfc3339(),
295 now.to_rfc3339(),
296 messages.len() as i64,
297 chunks.len() as i64,
298 workspace,
299 metadata_json,
300 ],
301 )?;
302
303 conn.execute(
305 "DELETE FROM session_chunks WHERE session_id = ?",
306 params![session_id],
307 )?;
308
309 for chunk in &chunks {
311 let mut metadata = HashMap::new();
312 metadata.insert("session_id".to_string(), serde_json::json!(session_id));
313 metadata.insert(
314 "chunk_index".to_string(),
315 serde_json::json!(chunk.chunk_index),
316 );
317 metadata.insert(
318 "start_message".to_string(),
319 serde_json::json!(chunk.start_index),
320 );
321 metadata.insert(
322 "end_message".to_string(),
323 serde_json::json!(chunk.end_index),
324 );
325 metadata.insert(
326 "message_count".to_string(),
327 serde_json::json!(chunk.messages.len()),
328 );
329
330 let input = CreateMemoryInput {
331 content: chunk.content.clone(),
332 memory_type: MemoryType::TranscriptChunk,
333 tags: vec!["transcript".to_string(), format!("session:{}", session_id)],
334 metadata,
335 importance: Some(0.3), scope: Default::default(),
337 workspace: Some(workspace.to_string()),
338 tier: MemoryTier::Daily, defer_embedding: false,
340 ttl_seconds: Some(config.default_ttl_seconds),
341 dedup_mode: Default::default(),
342 dedup_threshold: None,
343 event_time: None,
344 event_duration_seconds: None,
345 trigger_pattern: None,
346 summary_of_id: None,
347 };
348
349 let memory = create_memory(conn, &input)?;
350
351 conn.execute(
353 r#"
354 INSERT INTO session_chunks (session_id, memory_id, chunk_index,
355 start_message_index, end_message_index)
356 VALUES (?, ?, ?, ?, ?)
357 "#,
358 params![
359 session_id,
360 memory.id,
361 chunk.chunk_index as i64,
362 chunk.start_index as i64,
363 chunk.end_index as i64,
364 ],
365 )?;
366 }
367
368 tracing::info!(
369 session_id = session_id,
370 message_count = messages.len(),
371 chunk_count = chunks.len(),
372 "Indexed conversation"
373 );
374
375 Ok(Session {
376 session_id: session_id.to_string(),
377 title: title.map(String::from),
378 agent_id: agent_id.map(String::from),
379 started_at,
380 last_indexed_at: Some(now),
381 message_count: messages.len() as i64,
382 chunk_count: chunks.len() as i64,
383 workspace: workspace.to_string(),
384 metadata: HashMap::new(),
385 })
386}
387
388pub fn index_conversation_delta(
401 conn: &Connection,
402 session_id: &str,
403 new_messages: &[Message],
404 config: &ChunkingConfig,
405) -> Result<Session> {
406 let session: Option<Session> = conn
408 .query_row(
409 "SELECT session_id, title, agent_id, started_at, last_indexed_at,
410 message_count, chunk_count, workspace, metadata
411 FROM sessions WHERE session_id = ?",
412 params![session_id],
413 |row| {
414 let started_at: String = row.get(3)?;
415 let last_indexed_at: Option<String> = row.get(4)?;
416 let metadata_str: String = row.get(8)?;
417 Ok(Session {
418 session_id: row.get(0)?,
419 title: row.get(1)?,
420 agent_id: row.get(2)?,
421 started_at: DateTime::parse_from_rfc3339(&started_at)
422 .map(|dt| dt.with_timezone(&Utc))
423 .unwrap_or_else(|_| Utc::now()),
424 last_indexed_at: last_indexed_at.and_then(|s| {
425 DateTime::parse_from_rfc3339(&s)
426 .map(|dt| dt.with_timezone(&Utc))
427 .ok()
428 }),
429 message_count: row.get(5)?,
430 chunk_count: row.get(6)?,
431 workspace: row.get(7)?,
432 metadata: serde_json::from_str(&metadata_str).unwrap_or_default(),
433 })
434 },
435 )
436 .ok();
437
438 match session {
439 Some(existing) => {
440 let last_chunk_end: i64 = conn
442 .query_row(
443 "SELECT COALESCE(MAX(end_message_index), 0) FROM session_chunks WHERE session_id = ?",
444 params![session_id],
445 |row| row.get(0),
446 )
447 .unwrap_or(0);
448
449 let overlap_messages: Vec<Message> = existing
451 .metadata
452 .get("overlap_messages")
453 .and_then(|v| serde_json::from_value(v.clone()).ok())
454 .unwrap_or_default();
455
456 let overlap_count = overlap_messages.len();
458
459 let mut all_messages = overlap_messages;
461 all_messages.extend(new_messages.iter().cloned());
462
463 let chunks = chunk_conversation(&all_messages, config);
465
466 if chunks.is_empty() {
467 return Ok(existing);
468 }
469
470 let now = Utc::now();
471 let new_message_count = existing.message_count + new_messages.len() as i64;
472 let starting_chunk_index = existing.chunk_count;
473
474 let new_overlap: Vec<&Message> = all_messages
477 .iter()
478 .rev()
479 .take(config.overlap_messages)
480 .collect::<Vec<_>>()
481 .into_iter()
482 .rev()
483 .collect();
484
485 let mut updated_metadata = existing.metadata.clone();
487 updated_metadata.insert(
488 "overlap_messages".to_string(),
489 serde_json::to_value(&new_overlap).unwrap_or_default(),
490 );
491 let metadata_json = serde_json::to_string(&updated_metadata)?;
492
493 conn.execute(
495 "UPDATE sessions SET last_indexed_at = ?, message_count = ?, chunk_count = ?, metadata = ? WHERE session_id = ?",
496 params![
497 now.to_rfc3339(),
498 new_message_count,
499 existing.chunk_count + chunks.len() as i64,
500 metadata_json,
501 session_id,
502 ],
503 )?;
504
505 let base_offset = (last_chunk_end as usize).saturating_sub(overlap_count);
511
512 for (i, chunk) in chunks.iter().enumerate() {
514 let chunk_index = starting_chunk_index as usize + i;
515
516 let global_start = chunk.start_index + base_offset;
518 let global_end = chunk.end_index + base_offset;
519
520 let mut metadata = HashMap::new();
521 metadata.insert("session_id".to_string(), serde_json::json!(session_id));
522 metadata.insert("chunk_index".to_string(), serde_json::json!(chunk_index));
523 metadata.insert("start_message".to_string(), serde_json::json!(global_start));
524 metadata.insert("end_message".to_string(), serde_json::json!(global_end));
525 metadata.insert(
526 "message_count".to_string(),
527 serde_json::json!(chunk.messages.len()),
528 );
529
530 let input = CreateMemoryInput {
531 content: chunk.content.clone(),
532 memory_type: MemoryType::TranscriptChunk,
533 tags: vec!["transcript".to_string(), format!("session:{}", session_id)],
534 metadata,
535 importance: Some(0.3),
536 scope: Default::default(),
537 workspace: Some(existing.workspace.clone()),
538 tier: MemoryTier::Daily,
539 defer_embedding: false,
540 ttl_seconds: Some(config.default_ttl_seconds),
541 dedup_mode: Default::default(),
542 dedup_threshold: None,
543 event_time: None,
544 event_duration_seconds: None,
545 trigger_pattern: None,
546 summary_of_id: None,
547 };
548
549 let memory = create_memory(conn, &input)?;
550
551 conn.execute(
552 r#"
553 INSERT INTO session_chunks (session_id, memory_id, chunk_index,
554 start_message_index, end_message_index)
555 VALUES (?, ?, ?, ?, ?)
556 "#,
557 params![
558 session_id,
559 memory.id,
560 chunk_index as i64,
561 global_start as i64,
562 global_end as i64,
563 ],
564 )?;
565 }
566
567 Ok(Session {
568 message_count: new_message_count,
569 chunk_count: existing.chunk_count + chunks.len() as i64,
570 last_indexed_at: Some(now),
571 ..existing
572 })
573 }
574 None => {
575 index_conversation(conn, session_id, new_messages, config, None, None, None)
577 }
578 }
579}
580
581pub fn get_session(conn: &Connection, session_id: &str) -> Result<Session> {
583 conn.query_row(
584 "SELECT session_id, title, agent_id, started_at, last_indexed_at,
585 message_count, chunk_count, workspace, metadata
586 FROM sessions WHERE session_id = ?",
587 params![session_id],
588 |row| {
589 let started_at: String = row.get(3)?;
590 let last_indexed_at: Option<String> = row.get(4)?;
591 let metadata_str: String = row.get(8)?;
592 Ok(Session {
593 session_id: row.get(0)?,
594 title: row.get(1)?,
595 agent_id: row.get(2)?,
596 started_at: DateTime::parse_from_rfc3339(&started_at)
597 .map(|dt| dt.with_timezone(&Utc))
598 .unwrap_or_else(|_| Utc::now()),
599 last_indexed_at: last_indexed_at.and_then(|s| {
600 DateTime::parse_from_rfc3339(&s)
601 .map(|dt| dt.with_timezone(&Utc))
602 .ok()
603 }),
604 message_count: row.get(5)?,
605 chunk_count: row.get(6)?,
606 workspace: row.get(7)?,
607 metadata: serde_json::from_str(&metadata_str).unwrap_or_default(),
608 })
609 },
610 )
611 .map_err(|_| EngramError::NotFound(0))
612}
613
614pub fn list_sessions(
616 conn: &Connection,
617 workspace: Option<&str>,
618 limit: i64,
619) -> Result<Vec<Session>> {
620 let mut sql = String::from(
621 "SELECT session_id, title, agent_id, started_at, last_indexed_at,
622 message_count, chunk_count, workspace, metadata
623 FROM sessions",
624 );
625
626 let mut params: Vec<Box<dyn rusqlite::ToSql>> = vec![];
627
628 if let Some(ws) = workspace {
629 sql.push_str(" WHERE workspace = ?");
630 params.push(Box::new(ws.to_string()));
631 }
632
633 sql.push_str(" ORDER BY started_at DESC LIMIT ?");
634 params.push(Box::new(limit));
635
636 let param_refs: Vec<&dyn rusqlite::ToSql> = params.iter().map(|b| b.as_ref()).collect();
637 let mut stmt = conn.prepare(&sql)?;
638
639 let sessions = stmt
640 .query_map(param_refs.as_slice(), |row| {
641 let started_at: String = row.get(3)?;
642 let last_indexed_at: Option<String> = row.get(4)?;
643 let metadata_str: String = row.get(8)?;
644 Ok(Session {
645 session_id: row.get(0)?,
646 title: row.get(1)?,
647 agent_id: row.get(2)?,
648 started_at: DateTime::parse_from_rfc3339(&started_at)
649 .map(|dt| dt.with_timezone(&Utc))
650 .unwrap_or_else(|_| Utc::now()),
651 last_indexed_at: last_indexed_at.and_then(|s| {
652 DateTime::parse_from_rfc3339(&s)
653 .map(|dt| dt.with_timezone(&Utc))
654 .ok()
655 }),
656 message_count: row.get(5)?,
657 chunk_count: row.get(6)?,
658 workspace: row.get(7)?,
659 metadata: serde_json::from_str(&metadata_str).unwrap_or_default(),
660 })
661 })?
662 .filter_map(|r| r.ok())
663 .collect();
664
665 Ok(sessions)
666}
667
668pub fn delete_session(conn: &Connection, session_id: &str) -> Result<()> {
670 conn.execute(
672 r#"
673 UPDATE memories SET valid_to = datetime('now')
674 WHERE id IN (SELECT memory_id FROM session_chunks WHERE session_id = ?)
675 "#,
676 params![session_id],
677 )?;
678
679 conn.execute(
681 "DELETE FROM session_chunks WHERE session_id = ?",
682 params![session_id],
683 )?;
684
685 conn.execute(
687 "DELETE FROM sessions WHERE session_id = ?",
688 params![session_id],
689 )?;
690
691 Ok(())
692}
693
#[cfg(test)]
mod tests {
    use super::*;

    /// Build `count` messages alternating between the `user` and `assistant` roles. Each has
    /// content `"Message <n> "` followed by `char_len` `x` characters and id `msg-<n>`.
    fn make_messages(count: usize, char_len: usize) -> Vec<Message> {
        let padding = "x".repeat(char_len);
        let mut out = Vec::with_capacity(count);
        for n in 0..count {
            let role = if n % 2 == 0 { "user" } else { "assistant" };
            out.push(Message {
                role: role.to_string(),
                content: format!("Message {} {}", n, padding),
                timestamp: Utc::now(),
                id: Some(format!("msg-{}", n)),
            });
        }
        out
    }

    /// Build a single message with the given role and content and no id.
    fn message(role: &str, content: impl Into<String>) -> Message {
        Message {
            role: role.to_string(),
            content: content.into(),
            timestamp: Utc::now(),
            id: None,
        }
    }

    #[test]
    fn test_chunk_by_message_count() {
        let config = ChunkingConfig {
            max_messages: 3,
            overlap_messages: 1,
            max_chars: 100000,
            ..Default::default()
        };

        let chunks = chunk_conversation(&make_messages(7, 10), &config);

        // Windows of three messages, each starting on the last message of the previous one.
        let spans: Vec<(usize, usize)> = chunks
            .iter()
            .map(|c| (c.start_index, c.end_index))
            .collect();
        assert_eq!(spans, vec![(0, 3), (2, 5), (4, 7)]);
    }

    #[test]
    fn test_chunk_by_char_count() {
        let config = ChunkingConfig {
            max_messages: 100,
            overlap_messages: 1,
            max_chars: 100,
            ..Default::default()
        };

        let chunks = chunk_conversation(&make_messages(9, 20), &config);

        assert!(chunks.len() > 1);
        // `max_chars` bounds message content only; the formatted chunk also contains role
        // prefixes and separators, hence the slack.
        let limit = config.max_chars + 50;
        assert!(chunks.iter().all(|c| c.char_count <= limit));
    }

    #[test]
    fn test_truncate_long_message() {
        let config = ChunkingConfig {
            max_messages: 10,
            overlap_messages: 1,
            max_chars: 100,
            ..Default::default()
        };

        let messages = vec![message("user", "x".repeat(200))];
        let chunks = chunk_conversation(&messages, &config);

        assert_eq!(chunks.len(), 1);
        let stored = &chunks[0].messages[0].content;
        assert!(stored.contains("[...truncated...]"));
        assert!(stored.len() <= config.max_chars);
    }

    #[test]
    fn test_empty_conversation() {
        let chunks = chunk_conversation(&Vec::<Message>::new(), &ChunkingConfig::default());
        assert!(chunks.is_empty());
    }

    #[test]
    fn test_format_chunk_content() {
        let messages = [message("user", "Hello"), message("assistant", "Hi there!")];

        let content = format_chunk_content(&messages);

        assert!(content.contains("[user]: Hello"));
        assert!(content.contains("[assistant]: Hi there!"));
    }
}