1use super::types::{
2 BufferedEvent, ChatEvent, ClaudeCodeConfig, ClaudeCodeError, ClaudeCodeHandle, ClaudeCodeId,
3 ClaudeCodeInfo, Message, MessageId, MessageRole, Model, Position, StreamId,
4 StreamInfo, StreamStatus,
5};
6use crate::activations::arbor::{ArborStorage, NodeId, TreeId};
7use serde_json::Value;
8use sqlx::{sqlite::{SqliteConnectOptions, SqlitePool}, ConnectOptions, Row};
9use std::collections::HashMap;
10use std::path::PathBuf;
11use std::sync::Arc;
12use std::time::{SystemTime, UNIX_EPOCH};
13use tokio::sync::RwLock;
14use uuid::Uuid;
15
/// Settings for opening a [`ClaudeCodeStorage`].
#[derive(Debug, Clone)]
pub struct ClaudeCodeStorageConfig {
    /// Location of the SQLite database file; it is created on first open if absent.
    pub db_path: PathBuf,
}

impl Default for ClaudeCodeStorageConfig {
    /// Uses `claudecode.db`, resolved relative to the current working directory.
    fn default() -> Self {
        let db_path: PathBuf = "claudecode.db".into();
        Self { db_path }
    }
}
30
31#[derive(Debug)]
33struct ActiveStreamBuffer {
34 info: StreamInfo,
36 events: Vec<BufferedEvent>,
38}
39
40pub struct ClaudeCodeStorage {
42 pool: SqlitePool,
43 arbor: Arc<ArborStorage>,
44 streams: RwLock<HashMap<StreamId, ActiveStreamBuffer>>,
46}
47
48impl ClaudeCodeStorage {
49 pub async fn new(
51 config: ClaudeCodeStorageConfig,
52 arbor: Arc<ArborStorage>,
53 ) -> Result<Self, ClaudeCodeError> {
54 let db_url = format!("sqlite:{}?mode=rwc", config.db_path.display());
55 let connect_options: SqliteConnectOptions = db_url.parse()
56 .map_err(|e| format!("Failed to parse database URL: {}", e))?;
57 let connect_options = connect_options.disable_statement_logging();
58 let pool = SqlitePool::connect_with(connect_options.clone())
59 .await
60 .map_err(|e| format!("Failed to connect to claudecode database: {}", e))?;
61
62 let storage = Self {
63 pool,
64 arbor,
65 streams: RwLock::new(HashMap::new()),
66 };
67 storage.run_migrations().await?;
68
69 Ok(storage)
70 }
71
72 async fn run_migrations(&self) -> Result<(), ClaudeCodeError> {
74 sqlx::query(
75 r#"
76 CREATE TABLE IF NOT EXISTS claudecode_sessions (
77 id TEXT PRIMARY KEY,
78 name TEXT NOT NULL UNIQUE,
79 claude_session_id TEXT,
80 tree_id TEXT NOT NULL,
81 canonical_head TEXT NOT NULL,
82 working_dir TEXT NOT NULL,
83 model TEXT NOT NULL,
84 system_prompt TEXT,
85 mcp_config TEXT,
86 loopback_enabled INTEGER NOT NULL DEFAULT 0,
87 metadata TEXT,
88 created_at INTEGER NOT NULL,
89 updated_at INTEGER NOT NULL
90 );
91
92 CREATE TABLE IF NOT EXISTS claudecode_messages (
93 id TEXT PRIMARY KEY,
94 session_id TEXT NOT NULL,
95 role TEXT NOT NULL,
96 content TEXT NOT NULL,
97 model_id TEXT,
98 input_tokens INTEGER,
99 output_tokens INTEGER,
100 cost_usd REAL,
101 created_at INTEGER NOT NULL,
102 FOREIGN KEY (session_id) REFERENCES claudecode_sessions(id) ON DELETE CASCADE
103 );
104
105 CREATE INDEX IF NOT EXISTS idx_claudecode_sessions_name ON claudecode_sessions(name);
106 CREATE INDEX IF NOT EXISTS idx_claudecode_sessions_tree ON claudecode_sessions(tree_id);
107 CREATE INDEX IF NOT EXISTS idx_claudecode_messages_session ON claudecode_messages(session_id);
108
109 CREATE TABLE IF NOT EXISTS claudecode_unknown_events (
110 id TEXT PRIMARY KEY,
111 session_id TEXT,
112 event_type TEXT NOT NULL,
113 data TEXT NOT NULL,
114 created_at INTEGER NOT NULL,
115 FOREIGN KEY (session_id) REFERENCES claudecode_sessions(id) ON DELETE CASCADE
116 );
117
118 CREATE INDEX IF NOT EXISTS idx_claudecode_unknown_events_session ON claudecode_unknown_events(session_id);
119 CREATE INDEX IF NOT EXISTS idx_claudecode_unknown_events_type ON claudecode_unknown_events(event_type);
120 "#,
121 )
122 .execute(&self.pool)
123 .await
124 .map_err(|e| format!("Failed to run claudecode migrations: {}", e))?;
125
126 Ok(())
127 }
128
129 pub fn arbor(&self) -> &ArborStorage {
131 &self.arbor
132 }
133
134 pub async fn session_create(
140 &self,
141 name: String,
142 working_dir: String,
143 model: Model,
144 system_prompt: Option<String>,
145 mcp_config: Option<Value>,
146 loopback_enabled: bool,
147 metadata: Option<Value>,
148 ) -> Result<ClaudeCodeConfig, ClaudeCodeError> {
149 let session_id = ClaudeCodeId::new_v4();
150 let now = current_timestamp();
151
152 let tree_id = self
154 .arbor
155 .tree_create(metadata.clone(), &session_id.to_string())
156 .await
157 .map_err(|e| format!("Failed to create tree for session: {}", e))?;
158
159 let tree = self
161 .arbor
162 .tree_get(&tree_id)
163 .await
164 .map_err(|e| format!("Failed to get tree: {}", e))?;
165 let head = Position::new(tree_id, tree.root);
166
167 let metadata_json = metadata.as_ref().map(|m| serde_json::to_string(m).unwrap());
168 let mcp_config_json = mcp_config.as_ref().map(|m| serde_json::to_string(m).unwrap());
169
170 let final_name = match sqlx::query(
172 "INSERT INTO claudecode_sessions (id, name, claude_session_id, tree_id, canonical_head, working_dir, model, system_prompt, mcp_config, loopback_enabled, metadata, created_at, updated_at)
173 VALUES (?, ?, NULL, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
174 )
175 .bind(session_id.to_string())
176 .bind(&name)
177 .bind(head.tree_id.to_string())
178 .bind(head.node_id.to_string())
179 .bind(&working_dir)
180 .bind(model.as_str())
181 .bind(&system_prompt)
182 .bind(mcp_config_json.clone())
183 .bind(loopback_enabled)
184 .bind(metadata_json.clone())
185 .bind(now)
186 .bind(now)
187 .execute(&self.pool)
188 .await
189 {
190 Ok(_) => name,
191 Err(e) if e.to_string().contains("UNIQUE constraint failed") => {
192 let unique_name = format!("{}#{}", name, session_id);
194
195 sqlx::query(
196 "INSERT INTO claudecode_sessions (id, name, claude_session_id, tree_id, canonical_head, working_dir, model, system_prompt, mcp_config, loopback_enabled, metadata, created_at, updated_at)
197 VALUES (?, ?, NULL, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
198 )
199 .bind(session_id.to_string())
200 .bind(&unique_name)
201 .bind(head.tree_id.to_string())
202 .bind(head.node_id.to_string())
203 .bind(&working_dir)
204 .bind(model.as_str())
205 .bind(&system_prompt)
206 .bind(mcp_config_json)
207 .bind(loopback_enabled)
208 .bind(metadata_json)
209 .bind(now)
210 .bind(now)
211 .execute(&self.pool)
212 .await
213 .map_err(|e| format!("Failed to create session with unique name: {}", e))?;
214
215 unique_name
216 }
217 Err(e) => return Err(ClaudeCodeError::from(format!("Failed to create session: {}", e))),
218 };
219
220 Ok(ClaudeCodeConfig {
221 id: session_id,
222 name: final_name,
223 claude_session_id: None,
224 head,
225 working_dir,
226 model,
227 system_prompt,
228 mcp_config,
229 loopback_enabled,
230 metadata,
231 created_at: now,
232 updated_at: now,
233 })
234 }
235
236 pub async fn session_get(&self, session_id: &ClaudeCodeId) -> Result<ClaudeCodeConfig, ClaudeCodeError> {
238 let row = sqlx::query(
239 "SELECT id, name, claude_session_id, tree_id, canonical_head, working_dir, model, system_prompt, mcp_config, loopback_enabled, metadata, created_at, updated_at
240 FROM claudecode_sessions WHERE id = ?",
241 )
242 .bind(session_id.to_string())
243 .fetch_optional(&self.pool)
244 .await
245 .map_err(|e| format!("Failed to fetch session: {}", e))?
246 .ok_or_else(|| format!("Session not found: {}", session_id))?;
247
248 self.row_to_config(row)
249 }
250
251 pub async fn session_get_by_name(&self, name: &str) -> Result<ClaudeCodeConfig, ClaudeCodeError> {
253 if let Some(row) = sqlx::query(
255 "SELECT id, name, claude_session_id, tree_id, canonical_head, working_dir, model, system_prompt, mcp_config, loopback_enabled, metadata, created_at, updated_at
256 FROM claudecode_sessions WHERE name = ?",
257 )
258 .bind(name)
259 .fetch_optional(&self.pool)
260 .await
261 .map_err(|e| format!("Failed to fetch session by name: {}", e))?
262 {
263 return self.row_to_config(row);
264 }
265
266 let pattern = format!("{}%", name);
268 let rows = sqlx::query(
269 "SELECT id, name, claude_session_id, tree_id, canonical_head, working_dir, model, system_prompt, mcp_config, loopback_enabled, metadata, created_at, updated_at
270 FROM claudecode_sessions WHERE name LIKE ?",
271 )
272 .bind(&pattern)
273 .fetch_all(&self.pool)
274 .await
275 .map_err(|e| format!("Failed to fetch session by pattern: {}", e))?;
276
277 match rows.len() {
278 0 => Err(ClaudeCodeError::from(format!("Session not found with name: {}", name))),
279 1 => self.row_to_config(rows.into_iter().next().unwrap()),
280 _ => {
281 let matches: Vec<String> = rows.iter().map(|r| r.get("name")).collect();
282 Err(ClaudeCodeError::from(format!(
283 "Ambiguous name '{}' matches multiple sessions: {}",
284 name,
285 matches.join(", ")
286 )))
287 }
288 }
289 }
290
291 pub async fn session_list(&self) -> Result<Vec<ClaudeCodeInfo>, ClaudeCodeError> {
293 let rows = sqlx::query(
294 "SELECT id, name, claude_session_id, tree_id, canonical_head, working_dir, model, loopback_enabled, created_at
295 FROM claudecode_sessions ORDER BY created_at DESC",
296 )
297 .fetch_all(&self.pool)
298 .await
299 .map_err(|e| format!("Failed to list sessions: {}", e))?;
300
301 let sessions: Result<Vec<ClaudeCodeInfo>, ClaudeCodeError> = rows
302 .iter()
303 .map(|row| {
304 let id_str: String = row.get("id");
305 let tree_id_str: String = row.get("tree_id");
306 let head_str: String = row.get("canonical_head");
307 let model_str: String = row.get("model");
308 let loopback: i32 = row.get("loopback_enabled");
309
310 let tree_id = TreeId::parse_str(&tree_id_str)
311 .map_err(|e| format!("Invalid tree ID: {}", e))?;
312 let node_id = NodeId::parse_str(&head_str)
313 .map_err(|e| format!("Invalid node ID: {}", e))?;
314 let model = Model::from_str(&model_str)
315 .ok_or_else(|| format!("Invalid model: {}", model_str))?;
316
317 Ok(ClaudeCodeInfo {
318 id: Uuid::parse_str(&id_str).map_err(|e| format!("Invalid session ID: {}", e))?,
319 name: row.get("name"),
320 model,
321 head: Position::new(tree_id, node_id),
322 claude_session_id: row.get("claude_session_id"),
323 working_dir: row.get("working_dir"),
324 loopback_enabled: loopback != 0,
325 created_at: row.get("created_at"),
326 })
327 })
328 .collect();
329
330 sessions
331 }
332
333 pub async fn session_update_head(
335 &self,
336 session_id: &ClaudeCodeId,
337 new_head: NodeId,
338 claude_session_id: Option<String>,
339 ) -> Result<(), ClaudeCodeError> {
340 let now = current_timestamp();
341
342 let result = if let Some(claude_id) = claude_session_id {
343 sqlx::query(
344 "UPDATE claudecode_sessions SET canonical_head = ?, claude_session_id = ?, updated_at = ? WHERE id = ?",
345 )
346 .bind(new_head.to_string())
347 .bind(claude_id)
348 .bind(now)
349 .bind(session_id.to_string())
350 .execute(&self.pool)
351 .await
352 } else {
353 sqlx::query(
354 "UPDATE claudecode_sessions SET canonical_head = ?, updated_at = ? WHERE id = ?",
355 )
356 .bind(new_head.to_string())
357 .bind(now)
358 .bind(session_id.to_string())
359 .execute(&self.pool)
360 .await
361 }
362 .map_err(|e| format!("Failed to update session head: {}", e))?;
363
364 if result.rows_affected() == 0 {
365 return Err(format!("Session not found: {}", session_id).into());
366 }
367
368 Ok(())
369 }
370
371 pub async fn session_update(
373 &self,
374 session_id: &ClaudeCodeId,
375 name: Option<String>,
376 model: Option<Model>,
377 system_prompt: Option<Option<String>>,
378 mcp_config: Option<Value>,
379 metadata: Option<Value>,
380 ) -> Result<(), ClaudeCodeError> {
381 let now = current_timestamp();
382 let current = self.session_get(session_id).await?;
383
384 let new_name = name.unwrap_or(current.name);
385 let new_model = model.unwrap_or(current.model);
386 let new_prompt = system_prompt.unwrap_or(current.system_prompt);
387 let new_mcp = mcp_config.or(current.mcp_config);
388 let new_metadata = metadata.or(current.metadata);
389
390 let mcp_json = new_mcp.as_ref().map(|m| serde_json::to_string(m).unwrap());
391 let metadata_json = new_metadata.as_ref().map(|m| serde_json::to_string(m).unwrap());
392
393 sqlx::query(
394 "UPDATE claudecode_sessions SET name = ?, model = ?, system_prompt = ?, mcp_config = ?, metadata = ?, updated_at = ? WHERE id = ?",
395 )
396 .bind(&new_name)
397 .bind(new_model.as_str())
398 .bind(&new_prompt)
399 .bind(mcp_json)
400 .bind(metadata_json)
401 .bind(now)
402 .bind(session_id.to_string())
403 .execute(&self.pool)
404 .await
405 .map_err(|e| format!("Failed to update session: {}", e))?;
406
407 Ok(())
408 }
409
410 pub async fn session_delete(&self, session_id: &ClaudeCodeId) -> Result<(), ClaudeCodeError> {
412 let result = sqlx::query("DELETE FROM claudecode_sessions WHERE id = ?")
413 .bind(session_id.to_string())
414 .execute(&self.pool)
415 .await
416 .map_err(|e| format!("Failed to delete session: {}", e))?;
417
418 if result.rows_affected() == 0 {
419 return Err(format!("Session not found: {}", session_id).into());
420 }
421
422 Ok(())
423 }
424
425 pub async fn message_create(
431 &self,
432 session_id: &ClaudeCodeId,
433 role: MessageRole,
434 content: String,
435 model_id: Option<String>,
436 input_tokens: Option<i64>,
437 output_tokens: Option<i64>,
438 cost_usd: Option<f64>,
439 ) -> Result<Message, ClaudeCodeError> {
440 let message_id = MessageId::new_v4();
441 let now = current_timestamp();
442
443 sqlx::query(
444 "INSERT INTO claudecode_messages (id, session_id, role, content, model_id, input_tokens, output_tokens, cost_usd, created_at)
445 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
446 )
447 .bind(message_id.to_string())
448 .bind(session_id.to_string())
449 .bind(role.as_str())
450 .bind(&content)
451 .bind(&model_id)
452 .bind(input_tokens)
453 .bind(output_tokens)
454 .bind(cost_usd)
455 .bind(now)
456 .execute(&self.pool)
457 .await
458 .map_err(|e| format!("Failed to create message: {}", e))?;
459
460 Ok(Message {
461 id: message_id,
462 session_id: *session_id,
463 role,
464 content,
465 created_at: now,
466 model_id,
467 input_tokens,
468 output_tokens,
469 cost_usd,
470 })
471 }
472
473 pub async fn message_create_ephemeral(
475 &self,
476 session_id: &ClaudeCodeId,
477 role: MessageRole,
478 content: String,
479 model_id: Option<String>,
480 input_tokens: Option<i64>,
481 output_tokens: Option<i64>,
482 cost_usd: Option<f64>,
483 ) -> Result<Message, ClaudeCodeError> {
484 let message_id = MessageId::new_v4();
485 let now = current_timestamp();
486
487 let ephemeral_marker = -now;
491
492 sqlx::query(
493 "INSERT INTO claudecode_messages (id, session_id, role, content, model_id, input_tokens, output_tokens, cost_usd, created_at)
494 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
495 )
496 .bind(message_id.to_string())
497 .bind(session_id.to_string())
498 .bind(role.as_str())
499 .bind(&content)
500 .bind(&model_id)
501 .bind(input_tokens)
502 .bind(output_tokens)
503 .bind(cost_usd)
504 .bind(ephemeral_marker)
505 .execute(&self.pool)
506 .await
507 .map_err(|e| format!("Failed to create ephemeral message: {}", e))?;
508
509 Ok(Message {
510 id: message_id,
511 session_id: *session_id,
512 role,
513 content,
514 created_at: ephemeral_marker,
515 model_id,
516 input_tokens,
517 output_tokens,
518 cost_usd,
519 })
520 }
521
522 pub async fn message_get(&self, message_id: &MessageId) -> Result<Message, ClaudeCodeError> {
524 let row = sqlx::query(
525 "SELECT id, session_id, role, content, model_id, input_tokens, output_tokens, cost_usd, created_at
526 FROM claudecode_messages WHERE id = ?",
527 )
528 .bind(message_id.to_string())
529 .fetch_optional(&self.pool)
530 .await
531 .map_err(|e| format!("Failed to fetch message: {}", e))?
532 .ok_or_else(|| format!("Message not found: {}", message_id))?;
533
534 self.row_to_message(row)
535 }
536
537 pub async fn resolve_message_handle(&self, identifier: &str) -> Result<Message, ClaudeCodeError> {
540 let parts: Vec<&str> = identifier.splitn(3, ':').collect();
541 if parts.len() < 2 {
542 return Err(format!("Invalid message handle format: {}", identifier).into());
543 }
544
545 let msg_part = parts[0];
546 if !msg_part.starts_with("msg-") {
547 return Err(format!("Invalid message handle format: {}", identifier).into());
548 }
549
550 let message_id_str = &msg_part[4..];
551 let message_id = Uuid::parse_str(message_id_str)
552 .map_err(|e| format!("Invalid message ID in handle: {}", e))?;
553
554 self.message_get(&message_id).await
555 }
556
557 pub fn message_to_handle(message: &Message, name: &str) -> crate::types::Handle {
562 ClaudeCodeHandle::Message {
563 message_id: format!("msg-{}", message.id),
564 role: message.role.as_str().to_string(),
565 name: name.to_string(),
566 }.to_handle()
567 }
568
569 pub async fn unknown_event_store(
575 &self,
576 session_id: Option<&ClaudeCodeId>,
577 event_type: &str,
578 data: &Value,
579 ) -> Result<String, ClaudeCodeError> {
580 let id = Uuid::new_v4().to_string();
581 let now = current_timestamp();
582 let data_json = serde_json::to_string(data)
583 .map_err(|e| format!("Failed to serialize unknown event data: {}", e))?;
584
585 sqlx::query(
586 "INSERT INTO claudecode_unknown_events (id, session_id, event_type, data, created_at)
587 VALUES (?, ?, ?, ?, ?)",
588 )
589 .bind(&id)
590 .bind(session_id.map(|s| s.to_string()))
591 .bind(event_type)
592 .bind(&data_json)
593 .bind(now)
594 .execute(&self.pool)
595 .await
596 .map_err(|e| format!("Failed to store unknown event: {}", e))?;
597
598 Ok(id)
599 }
600
601 pub async fn unknown_event_get(&self, id: &str) -> Result<(String, Value), ClaudeCodeError> {
603 let row = sqlx::query(
604 "SELECT event_type, data FROM claudecode_unknown_events WHERE id = ?",
605 )
606 .bind(id)
607 .fetch_optional(&self.pool)
608 .await
609 .map_err(|e| format!("Failed to fetch unknown event: {}", e))?
610 .ok_or_else(|| format!("Unknown event not found: {}", id))?;
611
612 let event_type: String = row.get("event_type");
613 let data_json: String = row.get("data");
614 let data: Value = serde_json::from_str(&data_json)
615 .map_err(|e| format!("Failed to parse unknown event data: {}", e))?;
616
617 Ok((event_type, data))
618 }
619
620 pub async fn unknown_events_by_type(&self, event_type: &str) -> Result<Vec<(String, Value)>, ClaudeCodeError> {
622 let rows = sqlx::query(
623 "SELECT id, data FROM claudecode_unknown_events WHERE event_type = ? ORDER BY created_at DESC",
624 )
625 .bind(event_type)
626 .fetch_all(&self.pool)
627 .await
628 .map_err(|e| format!("Failed to list unknown events: {}", e))?;
629
630 rows.iter()
631 .map(|row| {
632 let id: String = row.get("id");
633 let data_json: String = row.get("data");
634 let data: Value = serde_json::from_str(&data_json)
635 .map_err(|e| format!("Failed to parse unknown event data: {}", e))?;
636 Ok((id, data))
637 })
638 .collect()
639 }
640
641 pub async fn stream_create(
647 &self,
648 session_id: ClaudeCodeId,
649 ) -> Result<StreamId, ClaudeCodeError> {
650 let stream_id = StreamId::new_v4();
651 let now = current_timestamp();
652
653 let info = StreamInfo {
654 stream_id,
655 session_id,
656 status: StreamStatus::Running,
657 user_position: None,
658 event_count: 0,
659 read_position: 0,
660 started_at: now,
661 ended_at: None,
662 error: None,
663 };
664
665 let buffer = ActiveStreamBuffer {
666 info,
667 events: Vec::new(),
668 };
669
670 let mut streams = self.streams.write().await;
671 streams.insert(stream_id, buffer);
672
673 Ok(stream_id)
674 }
675
676 pub async fn stream_set_user_position(
678 &self,
679 stream_id: &StreamId,
680 position: Position,
681 ) -> Result<(), ClaudeCodeError> {
682 let mut streams = self.streams.write().await;
683 let buffer = streams.get_mut(stream_id)
684 .ok_or_else(|| format!("Stream not found: {}", stream_id))?;
685 buffer.info.user_position = Some(position);
686 Ok(())
687 }
688
689 pub async fn stream_push_event(
691 &self,
692 stream_id: &StreamId,
693 event: ChatEvent,
694 ) -> Result<u64, ClaudeCodeError> {
695 let now = current_timestamp();
696 let mut streams = self.streams.write().await;
697 let buffer = streams.get_mut(stream_id)
698 .ok_or_else(|| format!("Stream not found: {}", stream_id))?;
699
700 let seq = buffer.info.event_count;
701 buffer.events.push(BufferedEvent {
702 seq,
703 event,
704 timestamp: now,
705 });
706 buffer.info.event_count += 1;
707
708 Ok(seq)
709 }
710
711 pub async fn stream_set_status(
713 &self,
714 stream_id: &StreamId,
715 status: StreamStatus,
716 error: Option<String>,
717 ) -> Result<(), ClaudeCodeError> {
718 let now = current_timestamp();
719 let mut streams = self.streams.write().await;
720 let buffer = streams.get_mut(stream_id)
721 .ok_or_else(|| format!("Stream not found: {}", stream_id))?;
722
723 buffer.info.status = status;
724 if status == StreamStatus::Complete || status == StreamStatus::Failed {
725 buffer.info.ended_at = Some(now);
726 }
727 if let Some(e) = error {
728 buffer.info.error = Some(e);
729 }
730
731 Ok(())
732 }
733
734 pub async fn stream_get_info(&self, stream_id: &StreamId) -> Result<StreamInfo, ClaudeCodeError> {
736 let streams = self.streams.read().await;
737 let buffer = streams.get(stream_id)
738 .ok_or_else(|| format!("Stream not found: {}", stream_id))?;
739 Ok(buffer.info.clone())
740 }
741
742 pub async fn stream_poll(
745 &self,
746 stream_id: &StreamId,
747 from_seq: Option<u64>,
748 limit: Option<usize>,
749 ) -> Result<(StreamInfo, Vec<BufferedEvent>), ClaudeCodeError> {
750 let mut streams = self.streams.write().await;
751 let buffer = streams.get_mut(stream_id)
752 .ok_or_else(|| format!("Stream not found: {}", stream_id))?;
753
754 let start = from_seq.unwrap_or(buffer.info.read_position) as usize;
755 let max_events = limit.unwrap_or(100);
756
757 let events: Vec<BufferedEvent> = buffer.events
758 .iter()
759 .skip(start)
760 .take(max_events)
761 .cloned()
762 .collect();
763
764 if !events.is_empty() {
766 let last_seq = events.last().unwrap().seq;
767 buffer.info.read_position = last_seq + 1;
768 }
769
770 Ok((buffer.info.clone(), events))
771 }
772
773 pub async fn stream_list(&self) -> Vec<StreamInfo> {
775 let streams = self.streams.read().await;
776 streams.values().map(|b| b.info.clone()).collect()
777 }
778
779 pub async fn stream_list_for_session(&self, session_id: &ClaudeCodeId) -> Vec<StreamInfo> {
781 let streams = self.streams.read().await;
782 streams
783 .values()
784 .filter(|b| &b.info.session_id == session_id)
785 .map(|b| b.info.clone())
786 .collect()
787 }
788
789 pub async fn stream_cleanup(&self, stream_id: &StreamId) -> Option<StreamInfo> {
792 let mut streams = self.streams.write().await;
793 streams.remove(stream_id).map(|b| b.info)
794 }
795
796 pub async fn stream_exists(&self, stream_id: &StreamId) -> bool {
798 let streams = self.streams.read().await;
799 streams.contains_key(stream_id)
800 }
801
802 fn row_to_message(&self, row: sqlx::sqlite::SqliteRow) -> Result<Message, ClaudeCodeError> {
807 let id_str: String = row.get("id");
808 let session_id_str: String = row.get("session_id");
809 let role_str: String = row.get("role");
810
811 Ok(Message {
812 id: Uuid::parse_str(&id_str).map_err(|e| format!("Invalid message ID: {}", e))?,
813 session_id: Uuid::parse_str(&session_id_str)
814 .map_err(|e| format!("Invalid session ID: {}", e))?,
815 role: MessageRole::from_str(&role_str)
816 .ok_or_else(|| format!("Invalid role: {}", role_str))?,
817 content: row.get("content"),
818 created_at: row.get("created_at"),
819 model_id: row.get("model_id"),
820 input_tokens: row.get("input_tokens"),
821 output_tokens: row.get("output_tokens"),
822 cost_usd: row.get("cost_usd"),
823 })
824 }
825
826 fn row_to_config(&self, row: sqlx::sqlite::SqliteRow) -> Result<ClaudeCodeConfig, ClaudeCodeError> {
827 let id_str: String = row.get("id");
828 let tree_id_str: String = row.get("tree_id");
829 let head_str: String = row.get("canonical_head");
830 let model_str: String = row.get("model");
831 let metadata_json: Option<String> = row.get("metadata");
832 let mcp_config_json: Option<String> = row.get("mcp_config");
833 let loopback: i32 = row.get("loopback_enabled");
834
835 let tree_id = TreeId::parse_str(&tree_id_str)
836 .map_err(|e| format!("Invalid tree ID: {}", e))?;
837 let node_id = NodeId::parse_str(&head_str)
838 .map_err(|e| format!("Invalid node ID: {}", e))?;
839 let model = Model::from_str(&model_str)
840 .ok_or_else(|| format!("Invalid model: {}", model_str))?;
841
842 Ok(ClaudeCodeConfig {
843 id: Uuid::parse_str(&id_str).map_err(|e| format!("Invalid session ID: {}", e))?,
844 name: row.get("name"),
845 claude_session_id: row.get("claude_session_id"),
846 head: Position::new(tree_id, node_id),
847 working_dir: row.get("working_dir"),
848 model,
849 system_prompt: row.get("system_prompt"),
850 mcp_config: mcp_config_json.and_then(|s| serde_json::from_str(&s).ok()),
851 loopback_enabled: loopback != 0,
852 metadata: metadata_json.and_then(|s| serde_json::from_str(&s).ok()),
853 created_at: row.get("created_at"),
854 updated_at: row.get("updated_at"),
855 })
856 }
857}
858
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// # Panics
/// Panics if the system clock reports a time before 1970-01-01.
fn current_timestamp() -> i64 {
    let elapsed = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
    elapsed.as_secs() as i64
}
866
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly-created, running stream buffer with no events.
    fn running_buffer(
        stream_id: StreamId,
        session_id: ClaudeCodeId,
        started_at: i64,
    ) -> ActiveStreamBuffer {
        ActiveStreamBuffer {
            info: StreamInfo {
                stream_id,
                session_id,
                status: StreamStatus::Running,
                user_position: None,
                event_count: 0,
                read_position: 0,
                started_at,
                ended_at: None,
                error: None,
            },
            events: Vec::new(),
        }
    }

    /// Append `event` with the next sequence number and bump the counter.
    fn push(buffer: &mut ActiveStreamBuffer, event: ChatEvent, timestamp: i64) {
        let seq = buffer.info.event_count;
        buffer.events.push(BufferedEvent { seq, event, timestamp });
        buffer.info.event_count = seq + 1;
    }

    /// Up to `limit` events starting at index `from`, cloned out of the buffer.
    fn unread_events(buffer: &ActiveStreamBuffer, from: usize, limit: usize) -> Vec<BufferedEvent> {
        buffer.events.iter().skip(from).take(limit).cloned().collect()
    }

    #[tokio::test]
    async fn test_stream_buffer_operations() {
        let streams: RwLock<HashMap<StreamId, ActiveStreamBuffer>> = RwLock::new(HashMap::new());
        let stream_id = StreamId::new_v4();
        let session_id = ClaudeCodeId::new_v4();
        let now = current_timestamp();

        streams
            .write()
            .await
            .insert(stream_id, running_buffer(stream_id, session_id, now));

        // Two events arrive.
        {
            let mut guard = streams.write().await;
            let buffer = guard.get_mut(&stream_id).unwrap();
            push(
                buffer,
                ChatEvent::Start {
                    id: session_id,
                    user_position: Position::new(TreeId::new(), NodeId::new()),
                },
                now,
            );
            push(buffer, ChatEvent::Content { text: "Hello".to_string() }, now);
        }

        // First poll sees both, then marks them as read.
        {
            let mut guard = streams.write().await;
            let buffer = guard.get_mut(&stream_id).unwrap();
            let batch = unread_events(buffer, 0, 10);
            assert_eq!(batch.len(), 2);
            assert_eq!(batch[0].seq, 0);
            assert_eq!(batch[1].seq, 1);
            buffer.info.read_position = 2;
        }

        // Polling again from the read position yields nothing new.
        {
            let guard = streams.read().await;
            let buffer = guard.get(&stream_id).unwrap();
            let batch = unread_events(buffer, buffer.info.read_position as usize, 10);
            assert_eq!(batch.len(), 0);
        }

        // A third event arrives.
        {
            let mut guard = streams.write().await;
            let buffer = guard.get_mut(&stream_id).unwrap();
            push(buffer, ChatEvent::Content { text: " World".to_string() }, now);
        }

        // Only the new event is returned.
        {
            let mut guard = streams.write().await;
            let buffer = guard.get_mut(&stream_id).unwrap();
            let batch = unread_events(buffer, buffer.info.read_position as usize, 10);
            assert_eq!(batch.len(), 1);
            assert_eq!(batch[0].seq, 2);
            buffer.info.read_position = 3;
        }

        // Status transitions.
        {
            let mut guard = streams.write().await;
            let buffer = guard.get_mut(&stream_id).unwrap();

            assert_eq!(buffer.info.status, StreamStatus::Running);

            buffer.info.status = StreamStatus::AwaitingPermission;
            assert_eq!(buffer.info.status, StreamStatus::AwaitingPermission);

            buffer.info.status = StreamStatus::Complete;
            buffer.info.ended_at = Some(current_timestamp());
            assert_eq!(buffer.info.status, StreamStatus::Complete);
            assert!(buffer.info.ended_at.is_some());
        }
    }

    #[test]
    fn test_stream_status_serialization() {
        let cases = [
            (StreamStatus::AwaitingPermission, "\"awaiting_permission\""),
            (StreamStatus::Running, "\"running\""),
        ];
        for (status, expected) in cases {
            let json = serde_json::to_string(&status).unwrap();
            assert_eq!(json, expected);
        }
    }
}