Skip to main content

plexus_substrate/activations/cone/
storage.rs

1use super::methods::ConeIdentifier;
2use super::types::{ConeConfig, ConeError, ConeHandle, ConeId, ConeInfo, Message, MessageId, MessageRole, Position};
3use crate::activations::arbor::{ArborStorage, NodeId, TreeId};
4use serde_json::Value;
5use sqlx::{sqlite::{SqliteConnectOptions, SqlitePool}, ConnectOptions, Row};
6use std::path::PathBuf;
7use std::sync::Arc;
8use std::time::{SystemTime, UNIX_EPOCH};
9use uuid::Uuid;
10
/// Settings used when opening a [`ConeStorage`].
#[derive(Debug, Clone)]
pub struct ConeStorageConfig {
    /// Filesystem location of the SQLite database that holds cone configs.
    pub db_path: PathBuf,
}

impl Default for ConeStorageConfig {
    /// Points at `cones.db`, a path relative to the process's working directory.
    fn default() -> Self {
        let db_path: PathBuf = "cones.db".into();
        Self { db_path }
    }
}
25
26/// Storage layer for cone configurations
27pub struct ConeStorage {
28    pool: SqlitePool,
29    arbor: Arc<ArborStorage>,
30}
31
32impl ConeStorage {
33    /// Create a new cone storage instance with a shared Arbor storage
34    pub async fn new(config: ConeStorageConfig, arbor: Arc<ArborStorage>) -> Result<Self, ConeError> {
35        // Initialize cone database
36        let db_url = format!("sqlite:{}?mode=rwc", config.db_path.display());
37        let connect_options: SqliteConnectOptions = db_url.parse()
38            .map_err(|e| format!("Failed to parse database URL: {}", e))?;
39        let connect_options = connect_options.disable_statement_logging();
40        let pool = SqlitePool::connect_with(connect_options.clone())
41            .await
42            .map_err(|e| format!("Failed to connect to cone database: {}", e))?;
43
44        let storage = Self { pool, arbor };
45        storage.run_migrations().await?;
46
47        Ok(storage)
48    }
49
50    /// Run database migrations
51    async fn run_migrations(&self) -> Result<(), ConeError> {
52        sqlx::query(
53            r#"
54            CREATE TABLE IF NOT EXISTS cones (
55                id TEXT PRIMARY KEY,
56                name TEXT NOT NULL UNIQUE,
57                model_id TEXT NOT NULL,
58                system_prompt TEXT,
59                tree_id TEXT NOT NULL,
60                canonical_head TEXT NOT NULL,
61                metadata TEXT,
62                created_at INTEGER NOT NULL,
63                updated_at INTEGER NOT NULL
64            );
65
66            CREATE TABLE IF NOT EXISTS messages (
67                id TEXT PRIMARY KEY,
68                cone_id TEXT NOT NULL,
69                role TEXT NOT NULL,
70                content TEXT NOT NULL,
71                model_id TEXT,
72                input_tokens INTEGER,
73                output_tokens INTEGER,
74                created_at INTEGER NOT NULL,
75                FOREIGN KEY (cone_id) REFERENCES cones(id) ON DELETE CASCADE
76            );
77
78            CREATE INDEX IF NOT EXISTS idx_cones_name ON cones(name);
79            CREATE INDEX IF NOT EXISTS idx_cones_tree ON cones(tree_id);
80            CREATE INDEX IF NOT EXISTS idx_messages_cone ON messages(cone_id);
81            "#,
82        )
83        .execute(&self.pool)
84        .await
85        .map_err(|e| format!("Failed to run cone migrations: {}", e))?;
86
87        Ok(())
88    }
89
90    /// Get access to the underlying arbor storage
91    pub fn arbor(&self) -> &ArborStorage {
92        &self.arbor
93    }
94
95    // ========================================================================
96    // Cone CRUD Operations
97    // ========================================================================
98
99    /// Create a new cone with a new conversation tree
100    ///
101    /// If a cone with the given name already exists, automatically appends `#<uuid>`
102    /// to make it unique. For example, "assistant" becomes "assistant#550e8400..."
103    pub async fn cone_create(
104        &self,
105        name: String,
106        model_id: String,
107        system_prompt: Option<String>,
108        metadata: Option<Value>,
109    ) -> Result<ConeConfig, ConeError> {
110        let cone_id = ConeId::new_v4();
111        let now = current_timestamp();
112
113        // Create a new tree for this cone
114        let tree_id = self.arbor.tree_create(metadata.clone(), &cone_id.to_string()).await
115            .map_err(|e| format!("Failed to create tree for cone: {}", e))?;
116
117        // Get the root node as initial position
118        let tree = self.arbor.tree_get(&tree_id).await
119            .map_err(|e| format!("Failed to get tree: {}", e))?;
120        let head = Position::new(tree_id, tree.root);
121
122        let metadata_json = metadata.as_ref().map(|m| serde_json::to_string(m).unwrap());
123
124        // Try inserting with the original name first
125        let final_name = match sqlx::query(
126            "INSERT INTO cones (id, name, model_id, system_prompt, tree_id, canonical_head, metadata, created_at, updated_at)
127             VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
128        )
129        .bind(cone_id.to_string())
130        .bind(&name)
131        .bind(&model_id)
132        .bind(&system_prompt)
133        .bind(head.tree_id.to_string())
134        .bind(head.node_id.to_string())
135        .bind(metadata_json.clone())
136        .bind(now)
137        .bind(now)
138        .execute(&self.pool)
139        .await {
140            Ok(_) => name,  // Success with original name
141            Err(e) if e.to_string().contains("UNIQUE constraint failed") => {
142                // Name collision - append #uuid to make it unique
143                let unique_name = format!("{}#{}", name, cone_id);
144
145                sqlx::query(
146                    "INSERT INTO cones (id, name, model_id, system_prompt, tree_id, canonical_head, metadata, created_at, updated_at)
147                     VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
148                )
149                .bind(cone_id.to_string())
150                .bind(&unique_name)
151                .bind(&model_id)
152                .bind(&system_prompt)
153                .bind(head.tree_id.to_string())
154                .bind(head.node_id.to_string())
155                .bind(metadata_json)
156                .bind(now)
157                .bind(now)
158                .execute(&self.pool)
159                .await
160                .map_err(|e| format!("Failed to create cone with unique name: {}", e))?;
161
162                unique_name
163            }
164            Err(e) => return Err(ConeError::from(format!("Failed to create cone: {}", e))),
165        };
166
167        Ok(ConeConfig {
168            id: cone_id,
169            name: final_name,
170            model_id,
171            system_prompt,
172            head,
173            metadata,
174            created_at: now,
175            updated_at: now,
176        })
177    }
178
179    /// Resolve a cone identifier to a ConeId
180    ///
181    /// For name lookups, supports partial matching on the name portion before '#':
182    /// - "assistant" matches "assistant" or "assistant#550e8400-..."
183    /// - "assistant#550e" matches "assistant#550e8400-..."
184    ///
185    /// Fails if the pattern matches multiple cones (ambiguous).
186    pub async fn resolve_cone_identifier(&self, identifier: &ConeIdentifier) -> Result<ConeId, ConeError> {
187        match identifier {
188            ConeIdentifier::ById { id } => Ok(*id),
189            ConeIdentifier::ByName { name } => {
190                // Try exact match first
191                if let Some(row) = sqlx::query("SELECT id FROM cones WHERE name = ?")
192                    .bind(name)
193                    .fetch_optional(&self.pool)
194                    .await
195                    .map_err(|e| ConeError::from(format!("Failed to resolve cone by name: {}", e)))?
196                {
197                    let id_str: String = row.get("id");
198                    return Uuid::parse_str(&id_str)
199                        .map_err(|e| ConeError::from(format!("Invalid cone ID in database: {}", e)));
200                }
201
202                // Try partial match with LIKE pattern
203                // Pattern: "name%" matches "name" or "name#uuid"
204                let pattern = format!("{}%", name);
205                let rows = sqlx::query("SELECT id, name FROM cones WHERE name LIKE ?")
206                    .bind(&pattern)
207                    .fetch_all(&self.pool)
208                    .await
209                    .map_err(|e| ConeError::from(format!("Failed to resolve cone by pattern: {}", e)))?;
210
211                match rows.len() {
212                    0 => Err(ConeError::from(format!("Cone not found with name: {}", name))),
213                    1 => {
214                        let id_str: String = rows[0].get("id");
215                        Uuid::parse_str(&id_str)
216                            .map_err(|e| ConeError::from(format!("Invalid cone ID in database: {}", e)))
217                    }
218                    _ => {
219                        // Multiple matches - list them for user
220                        let matches: Vec<String> = rows.iter().map(|r| r.get("name")).collect();
221                        Err(ConeError::from(format!(
222                            "Ambiguous name '{}' matches multiple cones: {}. Use full name with #uuid to disambiguate.",
223                            name,
224                            matches.join(", ")
225                        )))
226                    }
227                }
228            }
229        }
230    }
231
232    /// Get a cone by ID
233    pub async fn cone_get(&self, cone_id: &ConeId) -> Result<ConeConfig, ConeError> {
234        let row = sqlx::query(
235            "SELECT id, name, model_id, system_prompt, tree_id, canonical_head, metadata, created_at, updated_at
236             FROM cones WHERE id = ?",
237        )
238        .bind(cone_id.to_string())
239        .fetch_optional(&self.pool)
240        .await
241        .map_err(|e| format!("Failed to fetch cone: {}", e))?
242        .ok_or_else(|| format!("Cone not found: {}", cone_id))?;
243
244        self.row_to_cone_config(row)
245    }
246
247    /// Get a cone by identifier (name or UUID)
248    pub async fn cone_get_by_identifier(&self, identifier: &ConeIdentifier) -> Result<ConeConfig, ConeError> {
249        let cone_id = self.resolve_cone_identifier(identifier).await?;
250        self.cone_get(&cone_id).await
251    }
252
253    /// List all cones
254    pub async fn cone_list(&self) -> Result<Vec<ConeInfo>, ConeError> {
255        let rows = sqlx::query(
256            "SELECT id, name, model_id, tree_id, canonical_head, created_at FROM cones ORDER BY created_at DESC",
257        )
258        .fetch_all(&self.pool)
259        .await
260        .map_err(|e| format!("Failed to list cones: {}", e))?;
261
262        let cones: Result<Vec<ConeInfo>, ConeError> = rows
263            .iter()
264            .map(|row| {
265                let id_str: String = row.get("id");
266                let tree_id_str: String = row.get("tree_id");
267                let head_str: String = row.get("canonical_head");
268
269                let tree_id = TreeId::parse_str(&tree_id_str).map_err(|e| format!("Invalid tree ID: {}", e))?;
270                let node_id = NodeId::parse_str(&head_str).map_err(|e| format!("Invalid node ID: {}", e))?;
271
272                Ok(ConeInfo {
273                    id: Uuid::parse_str(&id_str).map_err(|e| format!("Invalid cone ID: {}", e))?,
274                    name: row.get("name"),
275                    model_id: row.get("model_id"),
276                    head: Position::new(tree_id, node_id),
277                    created_at: row.get("created_at"),
278                })
279            })
280            .collect();
281
282        cones
283    }
284
285    /// Update cone's canonical head
286    pub async fn cone_update_head(
287        &self,
288        cone_id: &ConeId,
289        new_head: NodeId,
290    ) -> Result<(), ConeError> {
291        let now = current_timestamp();
292
293        let result = sqlx::query(
294            "UPDATE cones SET canonical_head = ?, updated_at = ? WHERE id = ?",
295        )
296        .bind(new_head.to_string())
297        .bind(now)
298        .bind(cone_id.to_string())
299        .execute(&self.pool)
300        .await
301        .map_err(|e| format!("Failed to update cone head: {}", e))?;
302
303        if result.rows_affected() == 0 {
304            return Err(format!("Cone not found: {}", cone_id).into());
305        }
306
307        Ok(())
308    }
309
310    /// Update cone configuration
311    pub async fn cone_update(
312        &self,
313        cone_id: &ConeId,
314        name: Option<String>,
315        model_id: Option<String>,
316        system_prompt: Option<Option<String>>,
317        metadata: Option<Value>,
318    ) -> Result<(), ConeError> {
319        let now = current_timestamp();
320
321        // Get current cone
322        let current = self.cone_get(cone_id).await?;
323
324        let new_name = name.unwrap_or(current.name);
325        let new_model = model_id.unwrap_or(current.model_id);
326        let new_prompt = system_prompt.unwrap_or(current.system_prompt);
327        let new_metadata = metadata.or(current.metadata);
328        let metadata_json = new_metadata.as_ref().map(|m| serde_json::to_string(m).unwrap());
329
330        sqlx::query(
331            "UPDATE cones SET name = ?, model_id = ?, system_prompt = ?, metadata = ?, updated_at = ? WHERE id = ?",
332        )
333        .bind(&new_name)
334        .bind(&new_model)
335        .bind(&new_prompt)
336        .bind(metadata_json)
337        .bind(now)
338        .bind(cone_id.to_string())
339        .execute(&self.pool)
340        .await
341        .map_err(|e| format!("Failed to update cone: {}", e))?;
342
343        Ok(())
344    }
345
346    /// Delete an cone (does not delete the tree)
347    pub async fn cone_delete(&self, cone_id: &ConeId) -> Result<(), ConeError> {
348        let result = sqlx::query("DELETE FROM cones WHERE id = ?")
349            .bind(cone_id.to_string())
350            .execute(&self.pool)
351            .await
352            .map_err(|e| format!("Failed to delete cone: {}", e))?;
353
354        if result.rows_affected() == 0 {
355            return Err(format!("Cone not found: {}", cone_id).into());
356        }
357
358        Ok(())
359    }
360
361    // ========================================================================
362    // Message Operations
363    // ========================================================================
364
365    /// Create a message and return its ID
366    pub async fn message_create(
367        &self,
368        cone_id: &ConeId,
369        role: MessageRole,
370        content: String,
371        model_id: Option<String>,
372        input_tokens: Option<i64>,
373        output_tokens: Option<i64>,
374    ) -> Result<Message, ConeError> {
375        let message_id = MessageId::new_v4();
376        let now = current_timestamp();
377
378        sqlx::query(
379            "INSERT INTO messages (id, cone_id, role, content, model_id, input_tokens, output_tokens, created_at)
380             VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
381        )
382        .bind(message_id.to_string())
383        .bind(cone_id.to_string())
384        .bind(role.as_str())
385        .bind(&content)
386        .bind(&model_id)
387        .bind(input_tokens)
388        .bind(output_tokens)
389        .bind(now)
390        .execute(&self.pool)
391        .await
392        .map_err(|e| format!("Failed to create message: {}", e))?;
393
394        Ok(Message {
395            id: message_id,
396            cone_id: *cone_id,
397            role,
398            content,
399            created_at: now,
400            model_id,
401            input_tokens,
402            output_tokens,
403        })
404    }
405
406    /// Create an ephemeral message (marked for deletion) and return it
407    pub async fn message_create_ephemeral(
408        &self,
409        cone_id: &ConeId,
410        role: MessageRole,
411        content: String,
412        model_id: Option<String>,
413        input_tokens: Option<i64>,
414        output_tokens: Option<i64>,
415    ) -> Result<Message, ConeError> {
416        let message_id = MessageId::new_v4();
417        let now = current_timestamp();
418
419        // Use negative timestamp as ephemeral marker for cleanup
420        let ephemeral_marker = -now;
421
422        sqlx::query(
423            "INSERT INTO messages (id, cone_id, role, content, model_id, input_tokens, output_tokens, created_at)
424             VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
425        )
426        .bind(message_id.to_string())
427        .bind(cone_id.to_string())
428        .bind(role.as_str())
429        .bind(&content)
430        .bind(&model_id)
431        .bind(input_tokens)
432        .bind(output_tokens)
433        .bind(ephemeral_marker)
434        .execute(&self.pool)
435        .await
436        .map_err(|e| format!("Failed to create ephemeral message: {}", e))?;
437
438        Ok(Message {
439            id: message_id,
440            cone_id: *cone_id,
441            role,
442            content,
443            created_at: ephemeral_marker,
444            model_id,
445            input_tokens,
446            output_tokens,
447        })
448    }
449
450    /// Get a message by ID
451    pub async fn message_get(&self, message_id: &MessageId) -> Result<Message, ConeError> {
452        let row = sqlx::query(
453            "SELECT id, cone_id, role, content, model_id, input_tokens, output_tokens, created_at
454             FROM messages WHERE id = ?",
455        )
456        .bind(message_id.to_string())
457        .fetch_optional(&self.pool)
458        .await
459        .map_err(|e| format!("Failed to fetch message: {}", e))?
460        .ok_or_else(|| format!("Message not found: {}", message_id))?;
461
462        self.row_to_message(row)
463    }
464
465    /// Resolve a message handle identifier to a Message
466    /// Handle format: "msg-{message_id}:{role}:{name}"
467    pub async fn resolve_message_handle(&self, identifier: &str) -> Result<Message, ConeError> {
468        // Parse identifier: "msg-{uuid}:{role}:{name}"
469        let parts: Vec<&str> = identifier.splitn(3, ':').collect();
470        if parts.len() < 2 {
471            return Err(format!("Invalid message handle format: {}", identifier).into());
472        }
473
474        let msg_part = parts[0];
475        if !msg_part.starts_with("msg-") {
476            return Err(format!("Invalid message handle format: {}", identifier).into());
477        }
478
479        let message_id_str = &msg_part[4..]; // Strip "msg-" prefix
480        let message_id = Uuid::parse_str(message_id_str)
481            .map_err(|e| format!("Invalid message ID in handle: {}", e))?;
482
483        self.message_get(&message_id).await
484    }
485
486    /// Create a handle for a message
487    ///
488    /// Format: `{plugin_id}@1.0.0::chat:msg-{id}:{role}:{name}`
489    /// Uses ConeHandle enum for type-safe handle creation.
490    pub fn message_to_handle(message: &Message, name: &str) -> crate::types::Handle {
491        ConeHandle::Message {
492            message_id: format!("msg-{}", message.id),
493            role: message.role.as_str().to_string(),
494            name: name.to_string(),
495        }.to_handle()
496    }
497
498    // ========================================================================
499    // Helper methods
500    // ========================================================================
501
502    fn row_to_message(&self, row: sqlx::sqlite::SqliteRow) -> Result<Message, ConeError> {
503        let id_str: String = row.get("id");
504        let cone_id_str: String = row.get("cone_id");
505        let role_str: String = row.get("role");
506
507        Ok(Message {
508            id: Uuid::parse_str(&id_str).map_err(|e| format!("Invalid message ID: {}", e))?,
509            cone_id: Uuid::parse_str(&cone_id_str).map_err(|e| format!("Invalid cone ID: {}", e))?,
510            role: MessageRole::from_str(&role_str).ok_or_else(|| format!("Invalid role: {}", role_str))?,
511            content: row.get("content"),
512            created_at: row.get("created_at"),
513            model_id: row.get("model_id"),
514            input_tokens: row.get("input_tokens"),
515            output_tokens: row.get("output_tokens"),
516        })
517    }
518
519    fn row_to_cone_config(&self, row: sqlx::sqlite::SqliteRow) -> Result<ConeConfig, ConeError> {
520        let id_str: String = row.get("id");
521        let tree_id_str: String = row.get("tree_id");
522        let head_str: String = row.get("canonical_head");
523        let metadata_json: Option<String> = row.get("metadata");
524
525        let tree_id = TreeId::parse_str(&tree_id_str).map_err(|e| format!("Invalid tree ID: {}", e))?;
526        let node_id = NodeId::parse_str(&head_str).map_err(|e| format!("Invalid node ID: {}", e))?;
527
528        Ok(ConeConfig {
529            id: Uuid::parse_str(&id_str).map_err(|e| format!("Invalid cone ID: {}", e))?,
530            name: row.get("name"),
531            model_id: row.get("model_id"),
532            system_prompt: row.get("system_prompt"),
533            head: Position::new(tree_id, node_id),
534            metadata: metadata_json.and_then(|s| serde_json::from_str(&s).ok()),
535            created_at: row.get("created_at"),
536            updated_at: row.get("updated_at"),
537        })
538    }
539}
540
/// Seconds elapsed since the Unix epoch, as a signed integer for SQLite columns.
///
/// # Panics
/// Panics if the system clock reports a time before 1970-01-01 00:00:00 UTC.
fn current_timestamp() -> i64 {
    let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
    since_epoch.as_secs() as i64
}
548
#[cfg(test)]
mod tests {
    use super::super::Cone;
    use super::*;

    /// Build a throwaway message for handle tests.
    ///
    /// `ConeStorage::message_to_handle` reads only `id` and `role`, so the
    /// remaining fields are filled with inert values.
    fn sample_message(id: Uuid, role: MessageRole, content: &str) -> Message {
        Message {
            id,
            cone_id: Uuid::new_v4(),
            role,
            content: content.to_string(),
            created_at: 0,
            model_id: None,
            input_tokens: None,
            output_tokens: None,
        }
    }

    // ------------------------------------------------------------------------
    // INVARIANT: handle meta format consistency
    //
    // Joining the meta parts produced by `message_to_handle` with ':' must
    // yield exactly the identifier shape that `resolve_message_handle` parses.
    // This invariant was broken once before and is pinned here.
    // ------------------------------------------------------------------------

    #[test]
    fn invariant_handle_meta_format_matches_resolver() {
        let fixed_id = Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000").unwrap();
        let message = sample_message(fixed_id, MessageRole::User, "test content");

        // Build the handle as cone::chat does, then join the meta the way
        // resolve_handle_impl does.
        let identifier = ConeStorage::message_to_handle(&message, "test-cone").meta.join(":");

        // Expected shape: "msg-{uuid}:{role}:{name}"
        let segments: Vec<&str> = identifier.splitn(3, ':').collect();
        assert_eq!(segments.len(), 3, "identifier should have 3 parts: {}", identifier);

        let head = segments[0];
        assert!(head.starts_with("msg-"), "first part should start with 'msg-': {}", head);

        // The UUID after the "msg-" prefix must round-trip.
        let recovered = Uuid::parse_str(&head[4..]);
        assert!(recovered.is_ok(), "should be able to parse UUID from meta[0]");
        assert_eq!(recovered.unwrap(), message.id);

        // Role and name survive unchanged.
        assert_eq!(segments[1], "user");
        assert_eq!(segments[2], "test-cone");
    }

    #[test]
    fn invariant_handle_meta_roles() {
        let cases = [
            (MessageRole::User, "user"),
            (MessageRole::Assistant, "assistant"),
            (MessageRole::System, "system"),
        ];

        for (role, expected) in cases {
            let message = sample_message(Uuid::new_v4(), role, "test");
            let handle = ConeStorage::message_to_handle(&message, "cone");
            assert_eq!(handle.meta[1], expected);
        }
    }

    #[test]
    fn invariant_handle_plugin_method_fixed() {
        // Handles from message_to_handle always use the cone plugin and "chat" method.
        let message = sample_message(Uuid::new_v4(), MessageRole::User, "test");
        let handle = ConeStorage::message_to_handle(&message, "any-name");

        assert_eq!(handle.plugin_id, Cone::PLUGIN_ID);
        assert_eq!(handle.version, "1.0.0");
        assert_eq!(handle.method, "chat");
    }

    #[test]
    fn invariant_handle_meta_has_three_parts() {
        // Populated optional fields must not change the meta arity.
        let message = Message {
            model_id: Some("gpt-4".to_string()),
            input_tokens: Some(10),
            output_tokens: Some(20),
            ..sample_message(Uuid::new_v4(), MessageRole::Assistant, "response")
        };

        let handle = ConeStorage::message_to_handle(&message, "my-cone");

        assert_eq!(handle.meta.len(), 3, "cone chat handle must have exactly 3 meta parts");
    }
}