1#![allow(dead_code)]
11
12use anyhow::{Context, Result};
31use chrono::Utc;
32use rusqlite::{params, Connection, OptionalExtension};
33use serde::{Deserialize, Serialize};
34use std::path::Path;
35
/// Schema version identifier for this crate's database layout.
///
/// The test suite expects a freshly initialized database to report this value from the
/// `metadata` table's `schema_version` row.
pub const SCHEMA_VERSION: &str = "3.0";
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct Workspace {
46 pub id: String,
47 pub name: String,
48 pub path: Option<String>,
49 pub provider: String,
50 pub provider_workspace_id: Option<String>,
51 pub created_at: i64,
52 pub updated_at: i64,
53 pub metadata: Option<String>, }
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct Session {
59 pub id: String,
60 pub workspace_id: Option<String>,
61 pub provider: String,
62 pub provider_session_id: Option<String>,
63 pub title: String,
64 pub model: Option<String>,
65 pub message_count: i32,
66 pub token_count: Option<i64>,
67 pub created_at: i64,
68 pub updated_at: i64,
69 pub archived: bool,
70 pub metadata: Option<String>, }
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct Message {
76 pub id: String,
77 pub session_id: String,
78 pub role: String, pub content: String,
80 pub model: Option<String>,
81 pub token_count: Option<i32>,
82 pub created_at: i64,
83 pub parent_id: Option<String>, pub metadata: Option<String>,
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct Checkpoint {
90 pub id: String,
91 pub session_id: String,
92 pub name: String,
93 pub description: Option<String>,
94 pub message_count: i32,
95 pub session_snapshot: String, pub created_at: i64,
97 pub git_commit: Option<String>,
98}
99
100#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct ShareLink {
103 pub id: String,
104 pub session_id: Option<String>, pub provider: String,
106 pub url: String,
107 pub share_id: String, pub title: Option<String>,
109 pub imported: bool,
110 pub imported_at: Option<i64>,
111 pub created_at: i64,
112 pub metadata: Option<String>,
113}
114
/// Chat services whose share URLs can be recognized.
///
/// `Other` carries an arbitrary provider name; URL detection never produces it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ShareLinkProvider {
    /// `chat.openai.com/share/...` or `chatgpt.com/share/...`
    ChatGPT,
    /// `claude.ai/share/...`
    Claude,
    /// `g.co/gemini/share/...` or `gemini.google.com/share/...`
    Gemini,
    /// `perplexity.ai/search/...`
    Perplexity,
    /// `poe.com/s/...`
    Poe,
    /// Any other provider, identified by name.
    Other(String),
}
125
126impl ShareLinkProvider {
127 pub fn from_url(url: &str) -> Option<(Self, String)> {
129 let url_lower = url.to_lowercase();
130
131 if url_lower.contains("chat.openai.com/share/") || url_lower.contains("chatgpt.com/share/")
133 {
134 if let Some(id) = extract_path_segment(url, "share") {
135 return Some((ShareLinkProvider::ChatGPT, id));
136 }
137 }
138
139 if url_lower.contains("claude.ai/share/") {
141 if let Some(id) = extract_path_segment(url, "share") {
142 return Some((ShareLinkProvider::Claude, id));
143 }
144 }
145
146 if url_lower.contains("g.co/gemini/share/")
148 || url_lower.contains("gemini.google.com/share/")
149 {
150 if let Some(id) = extract_path_segment(url, "share") {
151 return Some((ShareLinkProvider::Gemini, id));
152 }
153 }
154
155 if url_lower.contains("perplexity.ai/search/") {
157 if let Some(id) = extract_path_segment(url, "search") {
158 return Some((ShareLinkProvider::Perplexity, id));
159 }
160 }
161
162 if url_lower.contains("poe.com/s/") {
164 if let Some(id) = extract_path_segment(url, "s") {
165 return Some((ShareLinkProvider::Poe, id));
166 }
167 }
168
169 None
170 }
171
172 pub fn name(&self) -> &str {
173 match self {
174 ShareLinkProvider::ChatGPT => "ChatGPT",
175 ShareLinkProvider::Claude => "Claude",
176 ShareLinkProvider::Gemini => "Gemini",
177 ShareLinkProvider::Perplexity => "Perplexity",
178 ShareLinkProvider::Poe => "Poe",
179 ShareLinkProvider::Other(name) => name,
180 }
181 }
182}
183
/// Result of [`ShareLinkParser::parse`]: the provider's display name and the share id.
#[derive(Clone, Debug)]
pub struct ShareLinkInfo {
    /// Provider display name (as returned by `ShareLinkProvider::name`).
    pub provider: String,
    /// Identifier extracted from the share URL.
    pub share_id: String,
}
190
191pub struct ShareLinkParser;
193
194impl ShareLinkParser {
195 pub fn parse(url: &str) -> Option<ShareLinkInfo> {
197 ShareLinkProvider::from_url(url).map(|(provider, share_id)| ShareLinkInfo {
198 provider: provider.name().to_string(),
199 share_id,
200 })
201 }
202}
203
/// Returns the URL path segment that follows a `key` segment.
///
/// The URL is split on `/`, and the first `key` segment whose successor yields a non-empty
/// identifier wins. `key` is compared ASCII case-insensitively, so extraction agrees with
/// `ShareLinkProvider::from_url`, which detects providers on a lower-cased copy of the URL.
/// Everything from the first `?` (query) or `#` (fragment) onward is dropped from the
/// segment. Segments left empty by that are skipped.
///
/// Returns `None` if no such segment exists.
fn extract_path_segment(url: &str, key: &str) -> Option<String> {
    let segments: Vec<&str> = url.split('/').collect();
    segments.windows(2).find_map(|pair| {
        let (label, candidate) = (pair[0], pair[1]);
        if !label.eq_ignore_ascii_case(key) {
            return None;
        }
        // Strip query string and fragment so "abc?x=1" and "abc#top" both yield "abc".
        let id = candidate
            .split(|c: char| c == '?' || c == '#')
            .next()
            .unwrap_or("");
        if id.is_empty() {
            None
        } else {
            Some(id.to_string())
        }
    })
}
217
218pub struct ChatDatabase {
224 pub conn: Connection,
225}
226
227impl ChatDatabase {
228 pub fn open(path: &Path) -> Result<Self> {
230 let conn = Connection::open(path).context("Failed to open database")?;
231
232 let db = ChatDatabase { conn };
233 db.initialize()?;
234
235 Ok(db)
236 }
237
238 #[allow(dead_code)]
240 pub fn open_in_memory() -> Result<Self> {
241 let conn = Connection::open_in_memory().context("Failed to create in-memory database")?;
242
243 let db = ChatDatabase { conn };
244 db.initialize()?;
245
246 Ok(db)
247 }
248
249 fn initialize(&self) -> Result<()> {
251 let is_harvest_db = self
254 .conn
255 .query_row("SELECT 1 FROM sessions LIMIT 1", [], |_| Ok(true))
256 .is_ok();
257
258 let has_model_column = self
259 .conn
260 .query_row("SELECT model FROM sessions LIMIT 1", [], |_| Ok(true))
261 .is_ok();
262
263 if !is_harvest_db || has_model_column {
265 self.conn
266 .execute_batch(include_str!("sql/schema.sql"))
267 .context("Failed to initialize database schema")?;
268 }
269 else {
272 self.conn
274 .execute_batch(
275 r#"
276 -- Metadata table for version tracking
277 CREATE TABLE IF NOT EXISTS metadata (
278 key TEXT PRIMARY KEY,
279 value TEXT NOT NULL,
280 updated_at INTEGER DEFAULT (strftime('%s', 'now'))
281 );
282 INSERT OR IGNORE INTO metadata (key, value) VALUES ('schema_version', 'harvest');
283
284 -- Agents table for agent management
285 CREATE TABLE IF NOT EXISTS agents (
286 id TEXT PRIMARY KEY,
287 name TEXT NOT NULL UNIQUE,
288 description TEXT,
289 instruction TEXT NOT NULL,
290 role TEXT DEFAULT 'assistant',
291 model TEXT,
292 provider TEXT,
293 temperature REAL DEFAULT 0.7,
294 max_tokens INTEGER,
295 tools TEXT,
296 sub_agents TEXT,
297 is_active INTEGER DEFAULT 1,
298 created_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')),
299 updated_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')),
300 metadata TEXT
301 );
302 "#,
303 )
304 .context("Failed to initialize harvest-compatible schema")?;
305 }
306 Ok(())
307 }
308
309 pub fn connection(&self) -> &Connection {
311 &self.conn
312 }
313
314 pub fn get_version(&self) -> Result<String> {
316 let version: String = self
317 .conn
318 .query_row(
319 "SELECT value FROM metadata WHERE key = 'schema_version'",
320 [],
321 |row| row.get(0),
322 )
323 .unwrap_or_else(|_| "unknown".to_string());
324 Ok(version)
325 }
326
327 pub fn upsert_workspace(&self, workspace: &Workspace) -> Result<()> {
333 self.conn.execute(
334 r#"
335 INSERT INTO workspaces (id, name, path, provider, provider_workspace_id,
336 created_at, updated_at, metadata)
337 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
338 ON CONFLICT(id) DO UPDATE SET
339 name = excluded.name,
340 path = excluded.path,
341 updated_at = excluded.updated_at,
342 metadata = excluded.metadata
343 "#,
344 params![
345 workspace.id,
346 workspace.name,
347 workspace.path,
348 workspace.provider,
349 workspace.provider_workspace_id,
350 workspace.created_at,
351 workspace.updated_at,
352 workspace.metadata,
353 ],
354 )?;
355 Ok(())
356 }
357
358 pub fn get_workspace(&self, id: &str) -> Result<Option<Workspace>> {
360 self.conn
361 .query_row(
362 "SELECT id, name, path, provider, provider_workspace_id, created_at, updated_at, metadata
363 FROM workspaces WHERE id = ?",
364 [id],
365 |row| {
366 Ok(Workspace {
367 id: row.get(0)?,
368 name: row.get(1)?,
369 path: row.get(2)?,
370 provider: row.get(3)?,
371 provider_workspace_id: row.get(4)?,
372 created_at: row.get(5)?,
373 updated_at: row.get(6)?,
374 metadata: row.get(7)?,
375 })
376 },
377 )
378 .optional()
379 .context("Failed to get workspace")
380 }
381
382 pub fn list_workspaces(&self) -> Result<Vec<Workspace>> {
384 let mut stmt = self.conn.prepare(
385 "SELECT id, name, path, provider, provider_workspace_id, created_at, updated_at, metadata
386 FROM workspaces ORDER BY updated_at DESC"
387 )?;
388
389 let workspaces = stmt
390 .query_map([], |row| {
391 Ok(Workspace {
392 id: row.get(0)?,
393 name: row.get(1)?,
394 path: row.get(2)?,
395 provider: row.get(3)?,
396 provider_workspace_id: row.get(4)?,
397 created_at: row.get(5)?,
398 updated_at: row.get(6)?,
399 metadata: row.get(7)?,
400 })
401 })?
402 .collect::<Result<Vec<_>, _>>()?;
403
404 Ok(workspaces)
405 }
406
407 pub fn upsert_session(&self, session: &Session) -> Result<bool> {
413 let existing = self.get_session(&session.id)?;
414
415 self.conn.execute(
416 r#"
417 INSERT INTO sessions (id, workspace_id, provider, provider_session_id, title,
418 model, message_count, token_count, created_at, updated_at,
419 archived, metadata)
420 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)
421 ON CONFLICT(id) DO UPDATE SET
422 title = excluded.title,
423 model = excluded.model,
424 message_count = excluded.message_count,
425 token_count = excluded.token_count,
426 updated_at = excluded.updated_at,
427 archived = excluded.archived,
428 metadata = excluded.metadata
429 "#,
430 params![
431 session.id,
432 session.workspace_id,
433 session.provider,
434 session.provider_session_id,
435 session.title,
436 session.model,
437 session.message_count,
438 session.token_count,
439 session.created_at,
440 session.updated_at,
441 session.archived,
442 session.metadata,
443 ],
444 )?;
445
446 Ok(existing.is_some())
447 }
448
449 pub fn get_session(&self, id: &str) -> Result<Option<Session>> {
451 self.conn
452 .query_row(
453 "SELECT id, workspace_id, provider, provider_session_id, title, model,
454 message_count, token_count, created_at, updated_at, archived, metadata
455 FROM sessions WHERE id = ?",
456 [id],
457 |row| {
458 Ok(Session {
459 id: row.get(0)?,
460 workspace_id: row.get(1)?,
461 provider: row.get(2)?,
462 provider_session_id: row.get(3)?,
463 title: row.get(4)?,
464 model: row.get(5)?,
465 message_count: row.get(6)?,
466 token_count: row.get(7)?,
467 created_at: row.get(8)?,
468 updated_at: row.get(9)?,
469 archived: row.get(10)?,
470 metadata: row.get(11)?,
471 })
472 },
473 )
474 .optional()
475 .context("Failed to get session")
476 }
477
478 pub fn list_sessions(
480 &self,
481 workspace_id: Option<&str>,
482 provider: Option<&str>,
483 limit: usize,
484 ) -> Result<Vec<Session>> {
485 let mut query = String::from(
486 "SELECT id, workspace_id, provider, provider_session_id, title, model,
487 message_count, token_count, created_at, updated_at, archived, metadata
488 FROM sessions WHERE 1=1",
489 );
490
491 if workspace_id.is_some() {
492 query.push_str(" AND workspace_id = ?1");
493 }
494 if provider.is_some() {
495 query.push_str(" AND provider = ?2");
496 }
497 query.push_str(" ORDER BY updated_at DESC LIMIT ?3");
498
499 let mut stmt = self.conn.prepare(&query)?;
500
501 let sessions = stmt
502 .query_map(
503 params![
504 workspace_id.unwrap_or(""),
505 provider.unwrap_or(""),
506 limit as i64,
507 ],
508 |row| {
509 Ok(Session {
510 id: row.get(0)?,
511 workspace_id: row.get(1)?,
512 provider: row.get(2)?,
513 provider_session_id: row.get(3)?,
514 title: row.get(4)?,
515 model: row.get(5)?,
516 message_count: row.get(6)?,
517 token_count: row.get(7)?,
518 created_at: row.get(8)?,
519 updated_at: row.get(9)?,
520 archived: row.get(10)?,
521 metadata: row.get(11)?,
522 })
523 },
524 )?
525 .collect::<Result<Vec<_>, _>>()?;
526
527 Ok(sessions)
528 }
529
530 pub fn count_sessions_by_provider(&self) -> Result<Vec<(String, i64)>> {
532 let mut stmt = self.conn.prepare(
533 "SELECT provider, COUNT(*) FROM sessions GROUP BY provider ORDER BY COUNT(*) DESC",
534 )?;
535
536 let counts = stmt
537 .query_map([], |row| Ok((row.get(0)?, row.get(1)?)))?
538 .collect::<Result<Vec<_>, _>>()?;
539
540 Ok(counts)
541 }
542
543 pub fn insert_message(&self, message: &Message) -> Result<()> {
549 self.conn.execute(
550 r#"
551 INSERT INTO messages (id, session_id, role, content, model, token_count,
552 created_at, parent_id, metadata)
553 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)
554 ON CONFLICT(id) DO UPDATE SET
555 content = excluded.content,
556 metadata = excluded.metadata
557 "#,
558 params![
559 message.id,
560 message.session_id,
561 message.role,
562 message.content,
563 message.model,
564 message.token_count,
565 message.created_at,
566 message.parent_id,
567 message.metadata,
568 ],
569 )?;
570 Ok(())
571 }
572
573 pub fn get_messages(&self, session_id: &str) -> Result<Vec<Message>> {
575 let mut stmt = self.conn.prepare(
576 "SELECT id, session_id, role, content, model, token_count, created_at, parent_id, metadata
577 FROM messages WHERE session_id = ? ORDER BY created_at ASC"
578 )?;
579
580 let messages = stmt
581 .query_map([session_id], |row| {
582 Ok(Message {
583 id: row.get(0)?,
584 session_id: row.get(1)?,
585 role: row.get(2)?,
586 content: row.get(3)?,
587 model: row.get(4)?,
588 token_count: row.get(5)?,
589 created_at: row.get(6)?,
590 parent_id: row.get(7)?,
591 metadata: row.get(8)?,
592 })
593 })?
594 .collect::<Result<Vec<_>, _>>()?;
595
596 Ok(messages)
597 }
598
599 pub fn create_checkpoint(
605 &self,
606 session_id: &str,
607 name: &str,
608 description: Option<&str>,
609 git_commit: Option<&str>,
610 ) -> Result<Checkpoint> {
611 let session = self.get_session(session_id)?.context("Session not found")?;
612
613 let messages = self.get_messages(session_id)?;
614
615 let snapshot = serde_json::json!({
616 "session": session,
617 "messages": messages,
618 });
619
620 let checkpoint = Checkpoint {
621 id: uuid::Uuid::new_v4().to_string(),
622 session_id: session_id.to_string(),
623 name: name.to_string(),
624 description: description.map(String::from),
625 message_count: messages.len() as i32,
626 session_snapshot: serde_json::to_string(&snapshot)?,
627 created_at: Utc::now().timestamp_millis(),
628 git_commit: git_commit.map(String::from),
629 };
630
631 self.conn.execute(
632 r#"
633 INSERT INTO checkpoints (id, session_id, name, description, message_count,
634 session_snapshot, created_at, git_commit)
635 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
636 "#,
637 params![
638 checkpoint.id,
639 checkpoint.session_id,
640 checkpoint.name,
641 checkpoint.description,
642 checkpoint.message_count,
643 checkpoint.session_snapshot,
644 checkpoint.created_at,
645 checkpoint.git_commit,
646 ],
647 )?;
648
649 Ok(checkpoint)
650 }
651
652 pub fn list_checkpoints(&self, session_id: &str) -> Result<Vec<Checkpoint>> {
654 let mut stmt = self.conn.prepare(
655 "SELECT id, session_id, name, description, message_count, session_snapshot, created_at, git_commit
656 FROM checkpoints WHERE session_id = ? ORDER BY created_at DESC"
657 )?;
658
659 let checkpoints = stmt
660 .query_map([session_id], |row| {
661 Ok(Checkpoint {
662 id: row.get(0)?,
663 session_id: row.get(1)?,
664 name: row.get(2)?,
665 description: row.get(3)?,
666 message_count: row.get(4)?,
667 session_snapshot: row.get(5)?,
668 created_at: row.get(6)?,
669 git_commit: row.get(7)?,
670 })
671 })?
672 .collect::<Result<Vec<_>, _>>()?;
673
674 Ok(checkpoints)
675 }
676
677 pub fn add_share_link(&self, link: &ShareLink) -> Result<()> {
683 self.conn.execute(
684 r#"
685 INSERT INTO share_links (id, session_id, provider, url, share_id, title,
686 imported, imported_at, created_at, metadata)
687 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
688 ON CONFLICT(url) DO UPDATE SET
689 session_id = COALESCE(excluded.session_id, share_links.session_id),
690 imported = excluded.imported,
691 imported_at = excluded.imported_at,
692 metadata = excluded.metadata
693 "#,
694 params![
695 link.id,
696 link.session_id,
697 link.provider,
698 link.url,
699 link.share_id,
700 link.title,
701 link.imported,
702 link.imported_at,
703 link.created_at,
704 link.metadata,
705 ],
706 )?;
707 Ok(())
708 }
709
710 pub fn get_share_link_by_url(&self, url: &str) -> Result<Option<ShareLink>> {
712 self.conn
713 .query_row(
714 "SELECT id, session_id, provider, url, share_id, title, imported, imported_at, created_at, metadata
715 FROM share_links WHERE url = ?",
716 [url],
717 |row| {
718 Ok(ShareLink {
719 id: row.get(0)?,
720 session_id: row.get(1)?,
721 provider: row.get(2)?,
722 url: row.get(3)?,
723 share_id: row.get(4)?,
724 title: row.get(5)?,
725 imported: row.get(6)?,
726 imported_at: row.get(7)?,
727 created_at: row.get(8)?,
728 metadata: row.get(9)?,
729 })
730 },
731 )
732 .optional()
733 .context("Failed to get share link")
734 }
735
736 pub fn list_share_links(&self, imported_only: bool) -> Result<Vec<ShareLink>> {
738 let query = if imported_only {
739 "SELECT id, session_id, provider, url, share_id, title, imported, imported_at, created_at, metadata
740 FROM share_links WHERE imported = 1 ORDER BY created_at DESC"
741 } else {
742 "SELECT id, session_id, provider, url, share_id, title, imported, imported_at, created_at, metadata
743 FROM share_links ORDER BY created_at DESC"
744 };
745
746 let mut stmt = self.conn.prepare(query)?;
747
748 let links = stmt
749 .query_map([], |row| {
750 Ok(ShareLink {
751 id: row.get(0)?,
752 session_id: row.get(1)?,
753 provider: row.get(2)?,
754 url: row.get(3)?,
755 share_id: row.get(4)?,
756 title: row.get(5)?,
757 imported: row.get(6)?,
758 imported_at: row.get(7)?,
759 created_at: row.get(8)?,
760 metadata: row.get(9)?,
761 })
762 })?
763 .collect::<Result<Vec<_>, _>>()?;
764
765 Ok(links)
766 }
767
768 pub fn mark_share_link_imported(&self, url: &str, session_id: &str) -> Result<()> {
770 let now = Utc::now().timestamp_millis();
771 self.conn.execute(
772 "UPDATE share_links SET imported = 1, imported_at = ?, session_id = ? WHERE url = ?",
773 params![now, session_id, url],
774 )?;
775 Ok(())
776 }
777
778 pub fn get_statistics(&self) -> Result<DatabaseStats> {
784 let workspace_count: i64 =
785 self.conn
786 .query_row("SELECT COUNT(*) FROM workspaces", [], |row| row.get(0))?;
787
788 let session_count: i64 =
789 self.conn
790 .query_row("SELECT COUNT(*) FROM sessions", [], |row| row.get(0))?;
791
792 let message_count: i64 =
793 self.conn
794 .query_row("SELECT COUNT(*) FROM messages", [], |row| row.get(0))?;
795
796 let checkpoint_count: i64 =
797 self.conn
798 .query_row("SELECT COUNT(*) FROM checkpoints", [], |row| row.get(0))?;
799
800 let share_link_count: i64 =
801 self.conn
802 .query_row("SELECT COUNT(*) FROM share_links", [], |row| row.get(0))?;
803
804 let imported_link_count: i64 = self.conn.query_row(
805 "SELECT COUNT(*) FROM share_links WHERE imported = 1",
806 [],
807 |row| row.get(0),
808 )?;
809
810 Ok(DatabaseStats {
811 workspace_count,
812 session_count,
813 message_count,
814 checkpoint_count,
815 share_link_count,
816 imported_link_count,
817 })
818 }
819}
820
821#[derive(Debug, Clone, Serialize, Deserialize)]
823pub struct DatabaseStats {
824 pub workspace_count: i64,
825 pub session_count: i64,
826 pub message_count: i64,
827 pub checkpoint_count: i64,
828 pub share_link_count: i64,
829 pub imported_link_count: i64,
830}
831
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_share_link_provider_parsing() {
        let (provider, id) =
            ShareLinkProvider::from_url("https://chat.openai.com/share/abc123").unwrap();
        assert_eq!(provider, ShareLinkProvider::ChatGPT);
        assert_eq!(id, "abc123");

        let (provider, id) = ShareLinkProvider::from_url("https://claude.ai/share/xyz789").unwrap();
        assert_eq!(provider, ShareLinkProvider::Claude);
        assert_eq!(id, "xyz789");

        let (provider, id) =
            ShareLinkProvider::from_url("https://www.perplexity.ai/search/test-query-123").unwrap();
        assert_eq!(provider, ShareLinkProvider::Perplexity);
        assert_eq!(id, "test-query-123");

        let (provider, id) =
            ShareLinkProvider::from_url("https://gemini.google.com/share/gem456").unwrap();
        assert_eq!(provider, ShareLinkProvider::Gemini);
        assert_eq!(id, "gem456");

        let (provider, id) = ShareLinkProvider::from_url("https://poe.com/s/PoE789").unwrap();
        assert_eq!(provider, ShareLinkProvider::Poe);
        assert_eq!(id, "PoE789");

        assert!(ShareLinkProvider::from_url("https://example.com/test").is_none());
    }

    #[test]
    fn test_share_link_id_edge_cases() {
        // Query strings are not part of the share id.
        let (_, id) =
            ShareLinkProvider::from_url("https://chatgpt.com/share/abc123?utm_source=x").unwrap();
        assert_eq!(id, "abc123");

        // A recognized share URL with nothing after the keyword yields no match.
        assert!(ShareLinkProvider::from_url("https://claude.ai/share/").is_none());
    }

    #[test]
    fn test_share_link_parser_and_names() {
        let info = ShareLinkParser::parse("https://claude.ai/share/xyz789").unwrap();
        assert_eq!(info.provider, "Claude");
        assert_eq!(info.share_id, "xyz789");
        assert!(ShareLinkParser::parse("https://example.com/test").is_none());

        assert_eq!(ShareLinkProvider::Other("Custom".to_string()).name(), "Custom");
    }

    #[test]
    fn test_database_creation() {
        let db = ChatDatabase::open_in_memory().unwrap();
        let version = db.get_version().unwrap();
        assert_eq!(version, SCHEMA_VERSION);
    }

    #[test]
    fn test_workspace_crud() {
        let db = ChatDatabase::open_in_memory().unwrap();

        let workspace = Workspace {
            id: "ws-1".to_string(),
            name: "Test Workspace".to_string(),
            path: Some("/test/path".to_string()),
            provider: "vscode".to_string(),
            provider_workspace_id: Some("hash123".to_string()),
            created_at: 1000,
            updated_at: 2000,
            metadata: None,
        };

        db.upsert_workspace(&workspace).unwrap();

        let retrieved = db.get_workspace("ws-1").unwrap().unwrap();
        assert_eq!(retrieved.name, "Test Workspace");

        let workspaces = db.list_workspaces().unwrap();
        assert_eq!(workspaces.len(), 1);

        assert!(db.get_workspace("missing").unwrap().is_none());
    }

    #[test]
    fn test_session_crud() {
        let db = ChatDatabase::open_in_memory().unwrap();

        let session = Session {
            id: "sess-1".to_string(),
            workspace_id: None,
            provider: "chatgpt".to_string(),
            provider_session_id: Some("gpt-abc".to_string()),
            title: "Test Session".to_string(),
            model: Some("gpt-4".to_string()),
            message_count: 5,
            token_count: Some(1000),
            created_at: 1000,
            updated_at: 2000,
            archived: false,
            metadata: None,
        };

        let was_update = db.upsert_session(&session).unwrap();
        assert!(!was_update);

        let retrieved = db.get_session("sess-1").unwrap().unwrap();
        assert_eq!(retrieved.title, "Test Session");

        // Upserting the same id again must report an update and persist the new values.
        let renamed = Session {
            title: "Renamed Session".to_string(),
            updated_at: 3000,
            ..session
        };
        assert!(db.upsert_session(&renamed).unwrap());
        let retrieved = db.get_session("sess-1").unwrap().unwrap();
        assert_eq!(retrieved.title, "Renamed Session");
        assert_eq!(retrieved.updated_at, 3000);

        assert!(db.get_session("missing").unwrap().is_none());
    }

    #[test]
    fn test_list_sessions_filters() {
        let db = ChatDatabase::open_in_memory().unwrap();

        let make = |id: &str, provider: &str, updated_at: i64| Session {
            id: id.to_string(),
            workspace_id: None,
            provider: provider.to_string(),
            provider_session_id: None,
            title: id.to_string(),
            model: None,
            message_count: 0,
            token_count: None,
            created_at: 1000,
            updated_at,
            archived: false,
            metadata: None,
        };

        db.upsert_session(&make("a", "chatgpt", 1000)).unwrap();
        db.upsert_session(&make("b", "claude", 2000)).unwrap();

        // Newest first.
        let all = db.list_sessions(None, None, 10).unwrap();
        let ids: Vec<&str> = all.iter().map(|s| s.id.as_str()).collect();
        assert_eq!(ids, vec!["b", "a"]);

        let claude = db.list_sessions(None, Some("claude"), 10).unwrap();
        assert_eq!(claude.len(), 1);
        assert_eq!(claude[0].id, "b");

        assert_eq!(db.list_sessions(None, None, 1).unwrap().len(), 1);
        assert!(db
            .list_sessions(Some("no-such-workspace"), None, 10)
            .unwrap()
            .is_empty());
    }

    #[test]
    fn test_share_link_operations() {
        let db = ChatDatabase::open_in_memory().unwrap();

        let link = ShareLink {
            id: "link-1".to_string(),
            session_id: None,
            provider: "ChatGPT".to_string(),
            url: "https://chat.openai.com/share/abc123".to_string(),
            share_id: "abc123".to_string(),
            title: Some("Shared Chat".to_string()),
            imported: false,
            imported_at: None,
            created_at: 1000,
            metadata: None,
        };

        db.add_share_link(&link).unwrap();

        let retrieved = db
            .get_share_link_by_url("https://chat.openai.com/share/abc123")
            .unwrap()
            .unwrap();
        assert_eq!(retrieved.share_id, "abc123");
        assert!(!retrieved.imported);

        let session = Session {
            id: "sess-123".to_string(),
            workspace_id: None,
            provider: "chatgpt".to_string(),
            provider_session_id: None,
            title: "Imported Session".to_string(),
            model: None,
            message_count: 0,
            token_count: None,
            created_at: 1000,
            updated_at: 1000,
            archived: false,
            metadata: None,
        };
        db.upsert_session(&session).unwrap();

        db.mark_share_link_imported(&link.url, "sess-123").unwrap();

        let updated = db.get_share_link_by_url(&link.url).unwrap().unwrap();
        assert!(updated.imported);
        assert_eq!(updated.session_id, Some("sess-123".to_string()));
    }

    #[test]
    fn test_checkpoint_creation() {
        let db = ChatDatabase::open_in_memory().unwrap();

        let session = Session {
            id: "sess-1".to_string(),
            workspace_id: None,
            provider: "test".to_string(),
            provider_session_id: None,
            title: "Test".to_string(),
            model: None,
            message_count: 0,
            token_count: None,
            created_at: 1000,
            updated_at: 1000,
            archived: false,
            metadata: None,
        };
        db.upsert_session(&session).unwrap();

        let checkpoint = db
            .create_checkpoint("sess-1", "v1.0", Some("First checkpoint"), None)
            .unwrap();
        assert_eq!(checkpoint.name, "v1.0");
        assert_eq!(checkpoint.message_count, 0);

        let checkpoints = db.list_checkpoints("sess-1").unwrap();
        assert_eq!(checkpoints.len(), 1);
        assert_eq!(checkpoints[0].id, checkpoint.id);

        // Checkpointing an unknown session is an error, not a silent no-op.
        assert!(db.create_checkpoint("missing", "v", None, None).is_err());
    }

    #[test]
    fn test_database_statistics() {
        let db = ChatDatabase::open_in_memory().unwrap();

        let stats = db.get_statistics().unwrap();
        assert_eq!(stats.workspace_count, 0);
        assert_eq!(stats.session_count, 0);
    }
}
1011}