Skip to main content

prompt_store/core/
storage.rs

1use super::utils::ensure_dir;
2use aes_gcm::{
3    aead::{Aead, KeyInit},
4    Aes256Gcm, Key, Nonce,
5};
6use base64::{engine::general_purpose, Engine as _};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::env;
10use std::fs;
11use std::path::{Path, PathBuf};
12
13use super::crypto::load_or_generate_key;
14
15/// Data for a single, storable prompt, including an optional I/O schema.
16#[derive(Serialize, Deserialize, Clone, Debug)]
17pub struct PromptData {
18    pub id: String,
19    pub title: String,
20    pub content: String,
21    pub tags: Vec<String>,
22    #[serde(default, skip_serializing_if = "Option::is_none")]
23    pub schema: Option<PromptSchema>,
24}
25
26/// Defines the expected inputs and output format (as a JSON Schema value) for a prompt.
27#[derive(Serialize, Deserialize, Clone, Debug)]
28pub struct PromptSchema {
29    #[serde(default, skip_serializing_if = "Option::is_none")]
30    pub inputs: Option<Value>,
31    #[serde(default, skip_serializing_if = "Option::is_none")]
32    pub output: Option<Value>,
33}
34
35/// Metadata for a prompt chain.
36#[derive(Serialize, Deserialize)]
37pub struct ChainData {
38    pub id: String,
39    pub title: String,
40}
41
42/// Runtime context holding paths and encryption keys.
43pub struct AppCtx {
44    pub base_dir: PathBuf,
45    pub workspaces_dir: PathBuf,
46    pub registries_dir: PathBuf,
47    pub key_path: PathBuf,
48    pub cipher: Aes256Gcm,
49}
50
/// Splits a prompt identifier into `(workspace, local_id)`.
///
/// The split happens at the first `::`; everything after it (further `::` included)
/// is the local ID. An identifier without `::` belongs to the `"default"` workspace
/// and is returned unchanged as the local ID.
pub fn parse_id(id: &str) -> (String, String) {
    let (workspace, local_id) = id.split_once("::").unwrap_or(("default", id));
    (workspace.to_owned(), local_id.to_owned())
}
59
60impl AppCtx {
61    /// Initializes the application context, creating necessary directories and loading the encryption key.
62    pub fn init() -> Result<Self, String> {
63        let home =
64            env::var("HOME").map_err(|_| "Unable to determine HOME directory".to_string())?;
65        let base_dir = PathBuf::from(home).join(".prompt-store");
66        let key_dir = base_dir.join("keys");
67        let key_path = key_dir.join("key.bin");
68        let workspaces_dir = base_dir.join("workspaces");
69        let registries_dir = base_dir.join("registries");
70
71        ensure_dir(&base_dir)?;
72        ensure_dir(&key_dir)?;
73        ensure_dir(&workspaces_dir)?;
74        ensure_dir(&workspaces_dir.join("default"))?; // Ensure default workspace exists
75        ensure_dir(&registries_dir)?;
76
77        let (key_bytes, _) = load_or_generate_key(&key_path)?;
78        let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&key_bytes));
79
80        Ok(Self {
81            base_dir,
82            workspaces_dir,
83            registries_dir,
84            key_path,
85            cipher,
86        })
87    }
88
89    /// Constructs the full path for a prompt file from its full ID.
90    pub fn prompt_path(&self, full_id: &str) -> PathBuf {
91        let (workspace, local_id) = parse_id(full_id);
92        let workspace_path = self.workspaces_dir.join(workspace);
93
94        if let Some((chain_id, step_id)) = local_id.split_once('/') {
95            workspace_path
96                .join(chain_id)
97                .join(format!("{}.prompt", step_id))
98        } else {
99            workspace_path.join(format!("{}.prompt", local_id))
100        }
101    }
102}
103
104/// Decrypts a prompt file to read its full data.
105pub fn decrypt_full_prompt(path: &Path, cipher: &Aes256Gcm) -> Result<PromptData, String> {
106    let encoded = fs::read_to_string(path).map_err(|e| format!("Read error: {}", e))?;
107    let decoded = general_purpose::STANDARD
108        .decode(encoded.trim_end())
109        .map_err(|_| "Corrupted data".to_string())?;
110    if decoded.len() < 12 {
111        return Err("Corrupted data".to_string());
112    }
113    let (nonce_bytes, cipher_bytes) = decoded.split_at(12);
114    let plaintext = cipher
115        .decrypt(Nonce::from_slice(nonce_bytes), cipher_bytes)
116        .map_err(|_| "Decrypt error".to_string())?;
117    serde_json::from_slice(&plaintext).map_err(|_| "Invalid JSON".to_string())
118}