1use super::methods::ConeIdentifier;
2use super::types::{ConeConfig, ConeError, ConeHandle, ConeId, ConeInfo, Message, MessageId, MessageRole, Position};
3use crate::activations::arbor::{ArborStorage, NodeId, TreeId};
4use serde_json::Value;
5use sqlx::{sqlite::{SqliteConnectOptions, SqlitePool}, ConnectOptions, Row};
6use std::path::PathBuf;
7use std::sync::Arc;
8use std::time::{SystemTime, UNIX_EPOCH};
9use uuid::Uuid;
10
/// Configuration for opening a `ConeStorage` database.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConeStorageConfig {
    /// Filesystem path of the SQLite database file; the file is created on
    /// first open if it does not already exist.
    pub db_path: PathBuf,
}

impl Default for ConeStorageConfig {
    /// Uses `cones.db`, a relative path resolved against the process's
    /// current working directory.
    fn default() -> Self {
        Self {
            db_path: PathBuf::from("cones.db"),
        }
    }
}
25
26pub struct ConeStorage {
28 pool: SqlitePool,
29 arbor: Arc<ArborStorage>,
30}
31
32impl ConeStorage {
33 pub async fn new(config: ConeStorageConfig, arbor: Arc<ArborStorage>) -> Result<Self, ConeError> {
35 let db_url = format!("sqlite:{}?mode=rwc", config.db_path.display());
37 let connect_options: SqliteConnectOptions = db_url.parse()
38 .map_err(|e| format!("Failed to parse database URL: {}", e))?;
39 let connect_options = connect_options.disable_statement_logging();
40 let pool = SqlitePool::connect_with(connect_options.clone())
41 .await
42 .map_err(|e| format!("Failed to connect to cone database: {}", e))?;
43
44 let storage = Self { pool, arbor };
45 storage.run_migrations().await?;
46
47 Ok(storage)
48 }
49
50 async fn run_migrations(&self) -> Result<(), ConeError> {
52 sqlx::query(
53 r#"
54 CREATE TABLE IF NOT EXISTS cones (
55 id TEXT PRIMARY KEY,
56 name TEXT NOT NULL UNIQUE,
57 model_id TEXT NOT NULL,
58 system_prompt TEXT,
59 tree_id TEXT NOT NULL,
60 canonical_head TEXT NOT NULL,
61 metadata TEXT,
62 created_at INTEGER NOT NULL,
63 updated_at INTEGER NOT NULL
64 );
65
66 CREATE TABLE IF NOT EXISTS messages (
67 id TEXT PRIMARY KEY,
68 cone_id TEXT NOT NULL,
69 role TEXT NOT NULL,
70 content TEXT NOT NULL,
71 model_id TEXT,
72 input_tokens INTEGER,
73 output_tokens INTEGER,
74 created_at INTEGER NOT NULL,
75 FOREIGN KEY (cone_id) REFERENCES cones(id) ON DELETE CASCADE
76 );
77
78 CREATE INDEX IF NOT EXISTS idx_cones_name ON cones(name);
79 CREATE INDEX IF NOT EXISTS idx_cones_tree ON cones(tree_id);
80 CREATE INDEX IF NOT EXISTS idx_messages_cone ON messages(cone_id);
81 "#,
82 )
83 .execute(&self.pool)
84 .await
85 .map_err(|e| format!("Failed to run cone migrations: {}", e))?;
86
87 Ok(())
88 }
89
90 pub fn arbor(&self) -> &ArborStorage {
92 &self.arbor
93 }
94
95 pub async fn cone_create(
104 &self,
105 name: String,
106 model_id: String,
107 system_prompt: Option<String>,
108 metadata: Option<Value>,
109 ) -> Result<ConeConfig, ConeError> {
110 let cone_id = ConeId::new_v4();
111 let now = current_timestamp();
112
113 let tree_id = self.arbor.tree_create(metadata.clone(), &cone_id.to_string()).await
115 .map_err(|e| format!("Failed to create tree for cone: {}", e))?;
116
117 let tree = self.arbor.tree_get(&tree_id).await
119 .map_err(|e| format!("Failed to get tree: {}", e))?;
120 let head = Position::new(tree_id, tree.root);
121
122 let metadata_json = metadata.as_ref().map(|m| serde_json::to_string(m).unwrap());
123
124 let final_name = match sqlx::query(
126 "INSERT INTO cones (id, name, model_id, system_prompt, tree_id, canonical_head, metadata, created_at, updated_at)
127 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
128 )
129 .bind(cone_id.to_string())
130 .bind(&name)
131 .bind(&model_id)
132 .bind(&system_prompt)
133 .bind(head.tree_id.to_string())
134 .bind(head.node_id.to_string())
135 .bind(metadata_json.clone())
136 .bind(now)
137 .bind(now)
138 .execute(&self.pool)
139 .await {
140 Ok(_) => name, Err(e) if e.to_string().contains("UNIQUE constraint failed") => {
142 let unique_name = format!("{}#{}", name, cone_id);
144
145 sqlx::query(
146 "INSERT INTO cones (id, name, model_id, system_prompt, tree_id, canonical_head, metadata, created_at, updated_at)
147 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
148 )
149 .bind(cone_id.to_string())
150 .bind(&unique_name)
151 .bind(&model_id)
152 .bind(&system_prompt)
153 .bind(head.tree_id.to_string())
154 .bind(head.node_id.to_string())
155 .bind(metadata_json)
156 .bind(now)
157 .bind(now)
158 .execute(&self.pool)
159 .await
160 .map_err(|e| format!("Failed to create cone with unique name: {}", e))?;
161
162 unique_name
163 }
164 Err(e) => return Err(ConeError::from(format!("Failed to create cone: {}", e))),
165 };
166
167 Ok(ConeConfig {
168 id: cone_id,
169 name: final_name,
170 model_id,
171 system_prompt,
172 head,
173 metadata,
174 created_at: now,
175 updated_at: now,
176 })
177 }
178
179 pub async fn resolve_cone_identifier(&self, identifier: &ConeIdentifier) -> Result<ConeId, ConeError> {
187 match identifier {
188 ConeIdentifier::ById { id } => Ok(*id),
189 ConeIdentifier::ByName { name } => {
190 if let Some(row) = sqlx::query("SELECT id FROM cones WHERE name = ?")
192 .bind(name)
193 .fetch_optional(&self.pool)
194 .await
195 .map_err(|e| ConeError::from(format!("Failed to resolve cone by name: {}", e)))?
196 {
197 let id_str: String = row.get("id");
198 return Uuid::parse_str(&id_str)
199 .map_err(|e| ConeError::from(format!("Invalid cone ID in database: {}", e)));
200 }
201
202 let pattern = format!("{}%", name);
205 let rows = sqlx::query("SELECT id, name FROM cones WHERE name LIKE ?")
206 .bind(&pattern)
207 .fetch_all(&self.pool)
208 .await
209 .map_err(|e| ConeError::from(format!("Failed to resolve cone by pattern: {}", e)))?;
210
211 match rows.len() {
212 0 => Err(ConeError::from(format!("Cone not found with name: {}", name))),
213 1 => {
214 let id_str: String = rows[0].get("id");
215 Uuid::parse_str(&id_str)
216 .map_err(|e| ConeError::from(format!("Invalid cone ID in database: {}", e)))
217 }
218 _ => {
219 let matches: Vec<String> = rows.iter().map(|r| r.get("name")).collect();
221 Err(ConeError::from(format!(
222 "Ambiguous name '{}' matches multiple cones: {}. Use full name with #uuid to disambiguate.",
223 name,
224 matches.join(", ")
225 )))
226 }
227 }
228 }
229 }
230 }
231
232 pub async fn cone_get(&self, cone_id: &ConeId) -> Result<ConeConfig, ConeError> {
234 let row = sqlx::query(
235 "SELECT id, name, model_id, system_prompt, tree_id, canonical_head, metadata, created_at, updated_at
236 FROM cones WHERE id = ?",
237 )
238 .bind(cone_id.to_string())
239 .fetch_optional(&self.pool)
240 .await
241 .map_err(|e| format!("Failed to fetch cone: {}", e))?
242 .ok_or_else(|| format!("Cone not found: {}", cone_id))?;
243
244 self.row_to_cone_config(row)
245 }
246
247 pub async fn cone_get_by_identifier(&self, identifier: &ConeIdentifier) -> Result<ConeConfig, ConeError> {
249 let cone_id = self.resolve_cone_identifier(identifier).await?;
250 self.cone_get(&cone_id).await
251 }
252
253 pub async fn cone_list(&self) -> Result<Vec<ConeInfo>, ConeError> {
255 let rows = sqlx::query(
256 "SELECT id, name, model_id, tree_id, canonical_head, created_at FROM cones ORDER BY created_at DESC",
257 )
258 .fetch_all(&self.pool)
259 .await
260 .map_err(|e| format!("Failed to list cones: {}", e))?;
261
262 let cones: Result<Vec<ConeInfo>, ConeError> = rows
263 .iter()
264 .map(|row| {
265 let id_str: String = row.get("id");
266 let tree_id_str: String = row.get("tree_id");
267 let head_str: String = row.get("canonical_head");
268
269 let tree_id = TreeId::parse_str(&tree_id_str).map_err(|e| format!("Invalid tree ID: {}", e))?;
270 let node_id = NodeId::parse_str(&head_str).map_err(|e| format!("Invalid node ID: {}", e))?;
271
272 Ok(ConeInfo {
273 id: Uuid::parse_str(&id_str).map_err(|e| format!("Invalid cone ID: {}", e))?,
274 name: row.get("name"),
275 model_id: row.get("model_id"),
276 head: Position::new(tree_id, node_id),
277 created_at: row.get("created_at"),
278 })
279 })
280 .collect();
281
282 cones
283 }
284
285 pub async fn cone_update_head(
287 &self,
288 cone_id: &ConeId,
289 new_head: NodeId,
290 ) -> Result<(), ConeError> {
291 let now = current_timestamp();
292
293 let result = sqlx::query(
294 "UPDATE cones SET canonical_head = ?, updated_at = ? WHERE id = ?",
295 )
296 .bind(new_head.to_string())
297 .bind(now)
298 .bind(cone_id.to_string())
299 .execute(&self.pool)
300 .await
301 .map_err(|e| format!("Failed to update cone head: {}", e))?;
302
303 if result.rows_affected() == 0 {
304 return Err(format!("Cone not found: {}", cone_id).into());
305 }
306
307 Ok(())
308 }
309
310 pub async fn cone_update(
312 &self,
313 cone_id: &ConeId,
314 name: Option<String>,
315 model_id: Option<String>,
316 system_prompt: Option<Option<String>>,
317 metadata: Option<Value>,
318 ) -> Result<(), ConeError> {
319 let now = current_timestamp();
320
321 let current = self.cone_get(cone_id).await?;
323
324 let new_name = name.unwrap_or(current.name);
325 let new_model = model_id.unwrap_or(current.model_id);
326 let new_prompt = system_prompt.unwrap_or(current.system_prompt);
327 let new_metadata = metadata.or(current.metadata);
328 let metadata_json = new_metadata.as_ref().map(|m| serde_json::to_string(m).unwrap());
329
330 sqlx::query(
331 "UPDATE cones SET name = ?, model_id = ?, system_prompt = ?, metadata = ?, updated_at = ? WHERE id = ?",
332 )
333 .bind(&new_name)
334 .bind(&new_model)
335 .bind(&new_prompt)
336 .bind(metadata_json)
337 .bind(now)
338 .bind(cone_id.to_string())
339 .execute(&self.pool)
340 .await
341 .map_err(|e| format!("Failed to update cone: {}", e))?;
342
343 Ok(())
344 }
345
346 pub async fn cone_delete(&self, cone_id: &ConeId) -> Result<(), ConeError> {
348 let result = sqlx::query("DELETE FROM cones WHERE id = ?")
349 .bind(cone_id.to_string())
350 .execute(&self.pool)
351 .await
352 .map_err(|e| format!("Failed to delete cone: {}", e))?;
353
354 if result.rows_affected() == 0 {
355 return Err(format!("Cone not found: {}", cone_id).into());
356 }
357
358 Ok(())
359 }
360
361 pub async fn message_create(
367 &self,
368 cone_id: &ConeId,
369 role: MessageRole,
370 content: String,
371 model_id: Option<String>,
372 input_tokens: Option<i64>,
373 output_tokens: Option<i64>,
374 ) -> Result<Message, ConeError> {
375 let message_id = MessageId::new_v4();
376 let now = current_timestamp();
377
378 sqlx::query(
379 "INSERT INTO messages (id, cone_id, role, content, model_id, input_tokens, output_tokens, created_at)
380 VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
381 )
382 .bind(message_id.to_string())
383 .bind(cone_id.to_string())
384 .bind(role.as_str())
385 .bind(&content)
386 .bind(&model_id)
387 .bind(input_tokens)
388 .bind(output_tokens)
389 .bind(now)
390 .execute(&self.pool)
391 .await
392 .map_err(|e| format!("Failed to create message: {}", e))?;
393
394 Ok(Message {
395 id: message_id,
396 cone_id: *cone_id,
397 role,
398 content,
399 created_at: now,
400 model_id,
401 input_tokens,
402 output_tokens,
403 })
404 }
405
406 pub async fn message_create_ephemeral(
408 &self,
409 cone_id: &ConeId,
410 role: MessageRole,
411 content: String,
412 model_id: Option<String>,
413 input_tokens: Option<i64>,
414 output_tokens: Option<i64>,
415 ) -> Result<Message, ConeError> {
416 let message_id = MessageId::new_v4();
417 let now = current_timestamp();
418
419 let ephemeral_marker = -now;
421
422 sqlx::query(
423 "INSERT INTO messages (id, cone_id, role, content, model_id, input_tokens, output_tokens, created_at)
424 VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
425 )
426 .bind(message_id.to_string())
427 .bind(cone_id.to_string())
428 .bind(role.as_str())
429 .bind(&content)
430 .bind(&model_id)
431 .bind(input_tokens)
432 .bind(output_tokens)
433 .bind(ephemeral_marker)
434 .execute(&self.pool)
435 .await
436 .map_err(|e| format!("Failed to create ephemeral message: {}", e))?;
437
438 Ok(Message {
439 id: message_id,
440 cone_id: *cone_id,
441 role,
442 content,
443 created_at: ephemeral_marker,
444 model_id,
445 input_tokens,
446 output_tokens,
447 })
448 }
449
450 pub async fn message_get(&self, message_id: &MessageId) -> Result<Message, ConeError> {
452 let row = sqlx::query(
453 "SELECT id, cone_id, role, content, model_id, input_tokens, output_tokens, created_at
454 FROM messages WHERE id = ?",
455 )
456 .bind(message_id.to_string())
457 .fetch_optional(&self.pool)
458 .await
459 .map_err(|e| format!("Failed to fetch message: {}", e))?
460 .ok_or_else(|| format!("Message not found: {}", message_id))?;
461
462 self.row_to_message(row)
463 }
464
465 pub async fn resolve_message_handle(&self, identifier: &str) -> Result<Message, ConeError> {
468 let parts: Vec<&str> = identifier.splitn(3, ':').collect();
470 if parts.len() < 2 {
471 return Err(format!("Invalid message handle format: {}", identifier).into());
472 }
473
474 let msg_part = parts[0];
475 if !msg_part.starts_with("msg-") {
476 return Err(format!("Invalid message handle format: {}", identifier).into());
477 }
478
479 let message_id_str = &msg_part[4..]; let message_id = Uuid::parse_str(message_id_str)
481 .map_err(|e| format!("Invalid message ID in handle: {}", e))?;
482
483 self.message_get(&message_id).await
484 }
485
486 pub fn message_to_handle(message: &Message, name: &str) -> crate::types::Handle {
491 ConeHandle::Message {
492 message_id: format!("msg-{}", message.id),
493 role: message.role.as_str().to_string(),
494 name: name.to_string(),
495 }.to_handle()
496 }
497
498 fn row_to_message(&self, row: sqlx::sqlite::SqliteRow) -> Result<Message, ConeError> {
503 let id_str: String = row.get("id");
504 let cone_id_str: String = row.get("cone_id");
505 let role_str: String = row.get("role");
506
507 Ok(Message {
508 id: Uuid::parse_str(&id_str).map_err(|e| format!("Invalid message ID: {}", e))?,
509 cone_id: Uuid::parse_str(&cone_id_str).map_err(|e| format!("Invalid cone ID: {}", e))?,
510 role: MessageRole::from_str(&role_str).ok_or_else(|| format!("Invalid role: {}", role_str))?,
511 content: row.get("content"),
512 created_at: row.get("created_at"),
513 model_id: row.get("model_id"),
514 input_tokens: row.get("input_tokens"),
515 output_tokens: row.get("output_tokens"),
516 })
517 }
518
519 fn row_to_cone_config(&self, row: sqlx::sqlite::SqliteRow) -> Result<ConeConfig, ConeError> {
520 let id_str: String = row.get("id");
521 let tree_id_str: String = row.get("tree_id");
522 let head_str: String = row.get("canonical_head");
523 let metadata_json: Option<String> = row.get("metadata");
524
525 let tree_id = TreeId::parse_str(&tree_id_str).map_err(|e| format!("Invalid tree ID: {}", e))?;
526 let node_id = NodeId::parse_str(&head_str).map_err(|e| format!("Invalid node ID: {}", e))?;
527
528 Ok(ConeConfig {
529 id: Uuid::parse_str(&id_str).map_err(|e| format!("Invalid cone ID: {}", e))?,
530 name: row.get("name"),
531 model_id: row.get("model_id"),
532 system_prompt: row.get("system_prompt"),
533 head: Position::new(tree_id, node_id),
534 metadata: metadata_json.and_then(|s| serde_json::from_str(&s).ok()),
535 created_at: row.get("created_at"),
536 updated_at: row.get("updated_at"),
537 })
538 }
539}
540
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time earlier than the Unix epoch.
fn current_timestamp() -> i64 {
    let now = SystemTime::now();
    let since_epoch = now.duration_since(UNIX_EPOCH).unwrap();
    since_epoch.as_secs() as i64
}
548
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::Cone;

    /// Builds a message with the given id, role and content, a random cone id,
    /// a zero timestamp and no model/token data.
    fn sample_message(id: Uuid, role: MessageRole, content: &str) -> Message {
        Message {
            id,
            cone_id: Uuid::new_v4(),
            role,
            content: content.to_string(),
            created_at: 0,
            model_id: None,
            input_tokens: None,
            output_tokens: None,
        }
    }

    /// Joining the handle meta with `:` must yield an identifier that splits
    /// the same way `ConeStorage::resolve_message_handle` splits it.
    #[test]
    fn invariant_handle_meta_format_matches_resolver() {
        let known_id = Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000").unwrap();
        let message = sample_message(known_id, MessageRole::User, "test content");

        let handle = ConeStorage::message_to_handle(&message, "test-cone");
        let identifier = handle.meta.join(":");
        let parts: Vec<&str> = identifier.splitn(3, ':').collect();

        assert_eq!(parts.len(), 3, "identifier should have 3 parts: {}", identifier);
        assert!(parts[0].starts_with("msg-"), "first part should start with 'msg-': {}", parts[0]);

        let parsed_id = Uuid::parse_str(&parts[0]["msg-".len()..]);
        assert!(parsed_id.is_ok(), "should be able to parse UUID from meta[0]");
        assert_eq!(parsed_id.unwrap(), message.id);

        assert_eq!(parts[1], "user");
        assert_eq!(parts[2], "test-cone");
    }

    /// The second meta entry is the lowercase role name for every role.
    #[test]
    fn invariant_handle_meta_roles() {
        let cases = [
            (MessageRole::User, "user"),
            (MessageRole::Assistant, "assistant"),
            (MessageRole::System, "system"),
        ];

        for (role, expected_str) in cases {
            let message = sample_message(Uuid::new_v4(), role, "test");
            let handle = ConeStorage::message_to_handle(&message, "cone");
            assert_eq!(handle.meta[1], expected_str);
        }
    }

    /// Plugin id, version and method are fixed regardless of the cone name.
    #[test]
    fn invariant_handle_plugin_method_fixed() {
        let message = sample_message(Uuid::new_v4(), MessageRole::User, "test");
        let handle = ConeStorage::message_to_handle(&message, "any-name");

        assert_eq!(handle.plugin_id, Cone::PLUGIN_ID);
        assert_eq!(handle.version, "1.0.0");
        assert_eq!(handle.method, "chat");
    }

    /// Model and token fields do not leak into the handle meta.
    #[test]
    fn invariant_handle_meta_has_three_parts() {
        let message = Message {
            model_id: Some("gpt-4".to_string()),
            input_tokens: Some(10),
            output_tokens: Some(20),
            ..sample_message(Uuid::new_v4(), MessageRole::Assistant, "response")
        };
        let handle = ConeStorage::message_to_handle(&message, "my-cone");

        assert_eq!(handle.meta.len(), 3, "cone chat handle must have exactly 3 meta parts");
    }
}