1use std::collections::HashMap;
13
14use chrono::{DateTime, Utc};
15use rusqlite::{params, Connection};
16use serde::{Deserialize, Serialize};
17
18use crate::error::{EngramError, Result};
19use crate::storage::queries::create_memory;
20use crate::types::{CreateMemoryInput, MemoryTier, MemoryType};
21
/// Tunables that control how a conversation is cut into chunks and how long
/// the resulting transcript memories are kept.
#[derive(Clone, Debug)]
pub struct ChunkingConfig {
    /// Upper bound on the number of messages placed in a single chunk.
    pub max_messages: usize,
    /// How many trailing messages of a chunk are repeated at the start of the
    /// next one so neighbouring chunks share context.
    pub overlap_messages: usize,
    /// Size budget for the message contents of one chunk; a single message
    /// larger than this is truncated before it is added.
    pub max_chars: usize,
    /// Time-to-live, in seconds, passed along when a chunk is stored as a memory.
    pub default_ttl_seconds: i64,
}

impl Default for ChunkingConfig {
    /// Ten messages per chunk, two messages of overlap, an 8000 character
    /// budget and a retention of seven days.
    fn default() -> Self {
        const SECONDS_PER_DAY: i64 = 24 * 60 * 60;
        const RETENTION_DAYS: i64 = 7;

        ChunkingConfig {
            max_messages: 10,
            overlap_messages: 2,
            max_chars: 8000,
            default_ttl_seconds: RETENTION_DAYS * SECONDS_PER_DAY,
        }
    }
}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct Message {
49 pub role: String,
51 pub content: String,
53 #[serde(default = "Utc::now")]
55 pub timestamp: DateTime<Utc>,
56 pub id: Option<String>,
58}
59
60#[derive(Debug, Clone)]
62pub struct ConversationChunk {
63 pub chunk_index: usize,
65 pub start_index: usize,
67 pub end_index: usize,
69 pub messages: Vec<Message>,
71 pub content: String,
73 pub char_count: usize,
75}
76
77#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct Session {
80 pub session_id: String,
82 pub title: Option<String>,
84 pub agent_id: Option<String>,
86 pub started_at: DateTime<Utc>,
88 pub last_indexed_at: Option<DateTime<Utc>>,
90 pub message_count: i64,
92 pub chunk_count: i64,
94 pub workspace: String,
96 pub metadata: HashMap<String, serde_json::Value>,
98}
99
100pub fn chunk_conversation(messages: &[Message], config: &ChunkingConfig) -> Vec<ConversationChunk> {
115 if messages.is_empty() {
116 return vec![];
117 }
118
119 let mut chunks = Vec::new();
120 let mut chunk_start = 0;
121
122 while chunk_start < messages.len() {
123 let mut current_messages = Vec::new();
124 let mut current_chars = 0;
125 let mut i = chunk_start;
126
127 while i < messages.len() {
129 let msg = &messages[i];
130 let msg_chars = msg.content.len();
131
132 let (content, chars) = if msg_chars > config.max_chars {
134 let truncated = truncate_with_marker(&msg.content, config.max_chars);
135 (truncated.clone(), truncated.len())
136 } else {
137 (msg.content.clone(), msg_chars)
138 };
139
140 let would_exceed_chars =
142 current_chars + chars > config.max_chars && !current_messages.is_empty();
143 let would_exceed_messages = current_messages.len() >= config.max_messages;
144
145 if would_exceed_chars || would_exceed_messages {
146 break;
147 }
148
149 current_messages.push(Message {
151 role: msg.role.clone(),
152 content,
153 timestamp: msg.timestamp,
154 id: msg.id.clone(),
155 });
156 current_chars += chars;
157 i += 1;
158 }
159
160 if !current_messages.is_empty() {
162 let chunk_content = format_chunk_content(¤t_messages);
163 chunks.push(ConversationChunk {
164 chunk_index: chunks.len(),
165 start_index: chunk_start,
166 end_index: i,
167 messages: current_messages,
168 content: chunk_content.clone(),
169 char_count: chunk_content.len(),
170 });
171 }
172
173 let overlap = config.overlap_messages.min(i - chunk_start);
175 chunk_start = if i >= messages.len() {
176 messages.len() } else if i > chunk_start + overlap {
178 i - overlap
179 } else {
180 i };
182 }
183
184 chunks
185}
186
/// Shortens `content` to at most `max_chars` characters, keeping the beginning
/// and the end and putting a truncation marker between them.
///
/// All lengths are counted in characters (Unicode scalar values), so text with
/// multi-byte characters is neither truncated prematurely nor cut inside a
/// character. Roughly 60% of the room left after the marker goes to the head
/// and the remainder to the tail.
///
/// * If `content` already fits, it is returned unchanged.
/// * If `max_chars` leaves no room for any text next to the marker, the
///   marker is omitted and the first `max_chars` characters are returned.
fn truncate_with_marker(content: &str, max_chars: usize) -> String {
    const MARKER: &str = "\n[...truncated...]\n";

    let total_chars = content.chars().count();
    if total_chars <= max_chars {
        return content.to_string();
    }

    let marker_chars = MARKER.chars().count();
    if max_chars <= marker_chars {
        // Subtracting the marker length would underflow (or leave nothing but
        // the marker), so fall back to a plain prefix.
        return content.chars().take(max_chars).collect();
    }

    let available = max_chars - marker_chars;
    let head_len = (available * 60) / 100;
    let tail_len = available - head_len;

    // Byte offset at which the `n`-th character starts (end of string if there
    // is no such character), so the slices below fall on char boundaries.
    let byte_offset = |n: usize| {
        content
            .char_indices()
            .nth(n)
            .map_or(content.len(), |(idx, _)| idx)
    };
    let head_end = byte_offset(head_len);
    // `total_chars > max_chars > head_len + tail_len`, so head and tail never overlap.
    let tail_start = byte_offset(total_chars - tail_len);

    let mut shortened =
        String::with_capacity(head_end + MARKER.len() + (content.len() - tail_start));
    shortened.push_str(&content[..head_end]);
    shortened.push_str(MARKER);
    shortened.push_str(&content[tail_start..]);
    shortened
}
211
212fn format_chunk_content(messages: &[Message]) -> String {
214 messages
215 .iter()
216 .map(|m| format!("[{}]: {}", m.role, m.content))
217 .collect::<Vec<_>>()
218 .join("\n\n")
219}
220
221pub fn index_conversation(
238 conn: &Connection,
239 session_id: &str,
240 messages: &[Message],
241 config: &ChunkingConfig,
242 workspace: Option<&str>,
243 title: Option<&str>,
244 agent_id: Option<&str>,
245) -> Result<Session> {
246 let now = Utc::now();
247 let workspace = workspace.unwrap_or("default");
248
249 let chunks = chunk_conversation(messages, config);
251
252 if chunks.is_empty() {
253 return Err(EngramError::InvalidInput(
254 "No messages to index".to_string(),
255 ));
256 }
257
258 let overlap_messages: Vec<&Message> = messages
260 .iter()
261 .rev()
262 .take(config.overlap_messages)
263 .collect::<Vec<_>>()
264 .into_iter()
265 .rev()
266 .collect();
267
268 let mut session_metadata = HashMap::new();
269 session_metadata.insert(
270 "overlap_messages".to_string(),
271 serde_json::to_value(&overlap_messages).unwrap_or_default(),
272 );
273 let metadata_json = serde_json::to_string(&session_metadata)?;
274
275 let started_at = messages.first().map(|m| m.timestamp).unwrap_or(now);
277
278 conn.execute(
279 r#"
280 INSERT INTO sessions (session_id, title, agent_id, started_at, last_indexed_at,
281 message_count, chunk_count, workspace, metadata)
282 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
283 ON CONFLICT(session_id) DO UPDATE SET
284 title = COALESCE(excluded.title, sessions.title),
285 last_indexed_at = excluded.last_indexed_at,
286 message_count = excluded.message_count,
287 chunk_count = excluded.chunk_count,
288 metadata = excluded.metadata
289 "#,
290 params![
291 session_id,
292 title,
293 agent_id,
294 started_at.to_rfc3339(),
295 now.to_rfc3339(),
296 messages.len() as i64,
297 chunks.len() as i64,
298 workspace,
299 metadata_json,
300 ],
301 )?;
302
303 conn.execute(
305 "DELETE FROM session_chunks WHERE session_id = ?",
306 params![session_id],
307 )?;
308
309 for chunk in &chunks {
311 let mut metadata = HashMap::new();
312 metadata.insert("session_id".to_string(), serde_json::json!(session_id));
313 metadata.insert(
314 "chunk_index".to_string(),
315 serde_json::json!(chunk.chunk_index),
316 );
317 metadata.insert(
318 "start_message".to_string(),
319 serde_json::json!(chunk.start_index),
320 );
321 metadata.insert(
322 "end_message".to_string(),
323 serde_json::json!(chunk.end_index),
324 );
325 metadata.insert(
326 "message_count".to_string(),
327 serde_json::json!(chunk.messages.len()),
328 );
329
330 let input = CreateMemoryInput {
331 content: chunk.content.clone(),
332 memory_type: MemoryType::TranscriptChunk,
333 tags: vec!["transcript".to_string(), format!("session:{}", session_id)],
334 metadata,
335 importance: Some(0.3), scope: Default::default(),
337 workspace: Some(workspace.to_string()),
338 tier: MemoryTier::Daily, defer_embedding: false,
340 ttl_seconds: Some(config.default_ttl_seconds),
341 dedup_mode: Default::default(),
342 dedup_threshold: None,
343 event_time: None,
344 event_duration_seconds: None,
345 trigger_pattern: None,
346 summary_of_id: None,
347 media_url: None,
348 };
349
350 let memory = create_memory(conn, &input)?;
351
352 conn.execute(
354 r#"
355 INSERT INTO session_chunks (session_id, memory_id, chunk_index,
356 start_message_index, end_message_index)
357 VALUES (?, ?, ?, ?, ?)
358 "#,
359 params![
360 session_id,
361 memory.id,
362 chunk.chunk_index as i64,
363 chunk.start_index as i64,
364 chunk.end_index as i64,
365 ],
366 )?;
367 }
368
369 tracing::info!(
370 session_id = session_id,
371 message_count = messages.len(),
372 chunk_count = chunks.len(),
373 "Indexed conversation"
374 );
375
376 Ok(Session {
377 session_id: session_id.to_string(),
378 title: title.map(String::from),
379 agent_id: agent_id.map(String::from),
380 started_at,
381 last_indexed_at: Some(now),
382 message_count: messages.len() as i64,
383 chunk_count: chunks.len() as i64,
384 workspace: workspace.to_string(),
385 metadata: HashMap::new(),
386 })
387}
388
389pub fn index_conversation_delta(
402 conn: &Connection,
403 session_id: &str,
404 new_messages: &[Message],
405 config: &ChunkingConfig,
406) -> Result<Session> {
407 let session: Option<Session> = conn
409 .query_row(
410 "SELECT session_id, title, agent_id, started_at, last_indexed_at,
411 message_count, chunk_count, workspace, metadata
412 FROM sessions WHERE session_id = ?",
413 params![session_id],
414 |row| {
415 let started_at: String = row.get(3)?;
416 let last_indexed_at: Option<String> = row.get(4)?;
417 let metadata_str: String = row.get(8)?;
418 Ok(Session {
419 session_id: row.get(0)?,
420 title: row.get(1)?,
421 agent_id: row.get(2)?,
422 started_at: DateTime::parse_from_rfc3339(&started_at)
423 .map(|dt| dt.with_timezone(&Utc))
424 .unwrap_or_else(|_| Utc::now()),
425 last_indexed_at: last_indexed_at.and_then(|s| {
426 DateTime::parse_from_rfc3339(&s)
427 .map(|dt| dt.with_timezone(&Utc))
428 .ok()
429 }),
430 message_count: row.get(5)?,
431 chunk_count: row.get(6)?,
432 workspace: row.get(7)?,
433 metadata: serde_json::from_str(&metadata_str).unwrap_or_default(),
434 })
435 },
436 )
437 .ok();
438
439 match session {
440 Some(existing) => {
441 let last_chunk_end: i64 = conn
443 .query_row(
444 "SELECT COALESCE(MAX(end_message_index), 0) FROM session_chunks WHERE session_id = ?",
445 params![session_id],
446 |row| row.get(0),
447 )
448 .unwrap_or(0);
449
450 let overlap_messages: Vec<Message> = existing
452 .metadata
453 .get("overlap_messages")
454 .and_then(|v| serde_json::from_value(v.clone()).ok())
455 .unwrap_or_default();
456
457 let overlap_count = overlap_messages.len();
459
460 let mut all_messages = overlap_messages;
462 all_messages.extend(new_messages.iter().cloned());
463
464 let chunks = chunk_conversation(&all_messages, config);
466
467 if chunks.is_empty() {
468 return Ok(existing);
469 }
470
471 let now = Utc::now();
472 let new_message_count = existing.message_count + new_messages.len() as i64;
473 let starting_chunk_index = existing.chunk_count;
474
475 let new_overlap: Vec<&Message> = all_messages
478 .iter()
479 .rev()
480 .take(config.overlap_messages)
481 .collect::<Vec<_>>()
482 .into_iter()
483 .rev()
484 .collect();
485
486 let mut updated_metadata = existing.metadata.clone();
488 updated_metadata.insert(
489 "overlap_messages".to_string(),
490 serde_json::to_value(&new_overlap).unwrap_or_default(),
491 );
492 let metadata_json = serde_json::to_string(&updated_metadata)?;
493
494 conn.execute(
496 "UPDATE sessions SET last_indexed_at = ?, message_count = ?, chunk_count = ?, metadata = ? WHERE session_id = ?",
497 params![
498 now.to_rfc3339(),
499 new_message_count,
500 existing.chunk_count + chunks.len() as i64,
501 metadata_json,
502 session_id,
503 ],
504 )?;
505
506 let base_offset = (last_chunk_end as usize).saturating_sub(overlap_count);
512
513 for (i, chunk) in chunks.iter().enumerate() {
515 let chunk_index = starting_chunk_index as usize + i;
516
517 let global_start = chunk.start_index + base_offset;
519 let global_end = chunk.end_index + base_offset;
520
521 let mut metadata = HashMap::new();
522 metadata.insert("session_id".to_string(), serde_json::json!(session_id));
523 metadata.insert("chunk_index".to_string(), serde_json::json!(chunk_index));
524 metadata.insert("start_message".to_string(), serde_json::json!(global_start));
525 metadata.insert("end_message".to_string(), serde_json::json!(global_end));
526 metadata.insert(
527 "message_count".to_string(),
528 serde_json::json!(chunk.messages.len()),
529 );
530
531 let input = CreateMemoryInput {
532 content: chunk.content.clone(),
533 memory_type: MemoryType::TranscriptChunk,
534 tags: vec!["transcript".to_string(), format!("session:{}", session_id)],
535 metadata,
536 importance: Some(0.3),
537 scope: Default::default(),
538 workspace: Some(existing.workspace.clone()),
539 tier: MemoryTier::Daily,
540 defer_embedding: false,
541 ttl_seconds: Some(config.default_ttl_seconds),
542 dedup_mode: Default::default(),
543 dedup_threshold: None,
544 event_time: None,
545 event_duration_seconds: None,
546 trigger_pattern: None,
547 summary_of_id: None,
548 media_url: None,
549 };
550
551 let memory = create_memory(conn, &input)?;
552
553 conn.execute(
554 r#"
555 INSERT INTO session_chunks (session_id, memory_id, chunk_index,
556 start_message_index, end_message_index)
557 VALUES (?, ?, ?, ?, ?)
558 "#,
559 params![
560 session_id,
561 memory.id,
562 chunk_index as i64,
563 global_start as i64,
564 global_end as i64,
565 ],
566 )?;
567 }
568
569 Ok(Session {
570 message_count: new_message_count,
571 chunk_count: existing.chunk_count + chunks.len() as i64,
572 last_indexed_at: Some(now),
573 ..existing
574 })
575 }
576 None => {
577 index_conversation(conn, session_id, new_messages, config, None, None, None)
579 }
580 }
581}
582
583pub fn get_session(conn: &Connection, session_id: &str) -> Result<Session> {
585 conn.query_row(
586 "SELECT session_id, title, agent_id, started_at, last_indexed_at,
587 message_count, chunk_count, workspace, metadata
588 FROM sessions WHERE session_id = ?",
589 params![session_id],
590 |row| {
591 let started_at: String = row.get(3)?;
592 let last_indexed_at: Option<String> = row.get(4)?;
593 let metadata_str: String = row.get(8)?;
594 Ok(Session {
595 session_id: row.get(0)?,
596 title: row.get(1)?,
597 agent_id: row.get(2)?,
598 started_at: DateTime::parse_from_rfc3339(&started_at)
599 .map(|dt| dt.with_timezone(&Utc))
600 .unwrap_or_else(|_| Utc::now()),
601 last_indexed_at: last_indexed_at.and_then(|s| {
602 DateTime::parse_from_rfc3339(&s)
603 .map(|dt| dt.with_timezone(&Utc))
604 .ok()
605 }),
606 message_count: row.get(5)?,
607 chunk_count: row.get(6)?,
608 workspace: row.get(7)?,
609 metadata: serde_json::from_str(&metadata_str).unwrap_or_default(),
610 })
611 },
612 )
613 .map_err(|_| EngramError::NotFound(0))
614}
615
616pub fn list_sessions(
618 conn: &Connection,
619 workspace: Option<&str>,
620 limit: i64,
621) -> Result<Vec<Session>> {
622 let mut sql = String::from(
623 "SELECT session_id, title, agent_id, started_at, last_indexed_at,
624 message_count, chunk_count, workspace, metadata
625 FROM sessions",
626 );
627
628 let mut params: Vec<Box<dyn rusqlite::ToSql>> = vec![];
629
630 if let Some(ws) = workspace {
631 sql.push_str(" WHERE workspace = ?");
632 params.push(Box::new(ws.to_string()));
633 }
634
635 sql.push_str(" ORDER BY started_at DESC LIMIT ?");
636 params.push(Box::new(limit));
637
638 let param_refs: Vec<&dyn rusqlite::ToSql> = params.iter().map(|b| b.as_ref()).collect();
639 let mut stmt = conn.prepare(&sql)?;
640
641 let sessions = stmt
642 .query_map(param_refs.as_slice(), |row| {
643 let started_at: String = row.get(3)?;
644 let last_indexed_at: Option<String> = row.get(4)?;
645 let metadata_str: String = row.get(8)?;
646 Ok(Session {
647 session_id: row.get(0)?,
648 title: row.get(1)?,
649 agent_id: row.get(2)?,
650 started_at: DateTime::parse_from_rfc3339(&started_at)
651 .map(|dt| dt.with_timezone(&Utc))
652 .unwrap_or_else(|_| Utc::now()),
653 last_indexed_at: last_indexed_at.and_then(|s| {
654 DateTime::parse_from_rfc3339(&s)
655 .map(|dt| dt.with_timezone(&Utc))
656 .ok()
657 }),
658 message_count: row.get(5)?,
659 chunk_count: row.get(6)?,
660 workspace: row.get(7)?,
661 metadata: serde_json::from_str(&metadata_str).unwrap_or_default(),
662 })
663 })?
664 .filter_map(|r| r.ok())
665 .collect();
666
667 Ok(sessions)
668}
669
670pub fn delete_session(conn: &Connection, session_id: &str) -> Result<()> {
672 conn.execute(
674 r#"
675 UPDATE memories SET valid_to = datetime('now')
676 WHERE id IN (SELECT memory_id FROM session_chunks WHERE session_id = ?)
677 "#,
678 params![session_id],
679 )?;
680
681 conn.execute(
683 "DELETE FROM session_chunks WHERE session_id = ?",
684 params![session_id],
685 )?;
686
687 conn.execute(
689 "DELETE FROM sessions WHERE session_id = ?",
690 params![session_id],
691 )?;
692
693 Ok(())
694}
695
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` messages with alternating user/assistant roles whose
    /// content is `"Message <i> "` followed by `char_len` filler characters.
    fn make_messages(count: usize, char_len: usize) -> Vec<Message> {
        let filler = "x".repeat(char_len);
        let mut out = Vec::with_capacity(count);
        for i in 0..count {
            let role = if i % 2 == 0 { "user" } else { "assistant" };
            out.push(Message {
                role: role.to_string(),
                content: format!("Message {} {}", i, filler),
                timestamp: Utc::now(),
                id: Some(format!("msg-{}", i)),
            });
        }
        out
    }

    /// Shorthand for a single message without an id.
    fn message(role: &str, content: &str) -> Message {
        Message {
            role: role.to_string(),
            content: content.to_string(),
            timestamp: Utc::now(),
            id: None,
        }
    }

    /// With a generous size budget only the message limit closes chunks, and
    /// each chunk starts one message before the previous one ended.
    #[test]
    fn test_chunk_by_message_count() {
        let config = ChunkingConfig {
            max_messages: 3,
            overlap_messages: 1,
            max_chars: 100000,
            ..Default::default()
        };

        let chunks = chunk_conversation(&make_messages(7, 10), &config);

        let bounds: Vec<(usize, usize)> = chunks
            .iter()
            .map(|c| (c.start_index, c.end_index))
            .collect();
        assert_eq!(bounds, vec![(0, 3), (2, 5), (4, 7)]);
    }

    /// With a generous message limit the size budget closes chunks; role
    /// prefixes and separators are allowed some slack on top of it.
    #[test]
    fn test_chunk_by_char_count() {
        let config = ChunkingConfig {
            max_messages: 100,
            overlap_messages: 1,
            max_chars: 100,
            ..Default::default()
        };

        let chunks = chunk_conversation(&make_messages(9, 20), &config);

        assert!(chunks.len() > 1);
        let ceiling = config.max_chars + 50;
        assert!(chunks.iter().all(|c| c.char_count <= ceiling));
    }

    /// A single oversized message is shortened to fit and marked as truncated.
    #[test]
    fn test_truncate_long_message() {
        let config = ChunkingConfig {
            max_messages: 10,
            overlap_messages: 1,
            max_chars: 100,
            ..Default::default()
        };

        let oversized = message("user", &"x".repeat(200));
        let chunks = chunk_conversation(&[oversized], &config);

        assert_eq!(chunks.len(), 1);
        let stored = &chunks[0].messages[0].content;
        assert!(stored.contains("[...truncated...]"));
        assert!(stored.len() <= config.max_chars);
    }

    /// No messages in, no chunks out.
    #[test]
    fn test_empty_conversation() {
        let chunks = chunk_conversation(&[], &ChunkingConfig::default());

        assert!(chunks.is_empty());
    }

    /// Every message shows up as `[role]: content` in the rendered text.
    #[test]
    fn test_format_chunk_content() {
        let messages = [message("user", "Hello"), message("assistant", "Hi there!")];

        let content = format_chunk_content(&messages);

        for expected in ["[user]: Hello", "[assistant]: Hi there!"] {
            assert!(content.contains(expected));
        }
    }
}