Skip to main content

batuta/serve/banco/
prompts.rs

//! System prompt presets for Banco.
//!
//! Save named system prompts and reference them in chat via `@preset:<id>`,
//! where `<id>` is the preset's ID (not its display name).
4
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::sync::RwLock;
8
9/// A saved system prompt preset.
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct PromptPreset {
12    pub id: String,
13    pub name: String,
14    pub content: String,
15    pub created: u64,
16}
17
18/// In-memory preset store.
19pub struct PromptStore {
20    presets: RwLock<HashMap<String, PromptPreset>>,
21    counter: std::sync::atomic::AtomicU64,
22}
23
24impl PromptStore {
25    #[must_use]
26    pub fn new() -> Self {
27        let store = Self {
28            presets: RwLock::new(HashMap::new()),
29            counter: std::sync::atomic::AtomicU64::new(0),
30        };
31        // Seed with built-in presets
32        store.seed_defaults();
33        store
34    }
35
36    fn seed_defaults(&self) {
37        let defaults = [
38            ("coding", "Coding Assistant", "You are an expert software engineer. Write clean, tested, idiomatic code. Explain your reasoning."),
39            ("concise", "Concise", "You are a helpful assistant. Be concise and direct. No filler."),
40            ("tutor", "Tutor", "You are a patient tutor. Explain concepts step by step. Ask the student questions to check understanding."),
41        ];
42        for (id, name, content) in defaults {
43            if let Ok(mut store) = self.presets.write() {
44                store.insert(
45                    id.to_string(),
46                    PromptPreset {
47                        id: id.to_string(),
48                        name: name.to_string(),
49                        content: content.to_string(),
50                        created: 0,
51                    },
52                );
53            }
54        }
55    }
56
57    /// Create or update a preset.
58    pub fn save(&self, id: &str, name: &str, content: &str) -> PromptPreset {
59        let preset = PromptPreset {
60            id: id.to_string(),
61            name: name.to_string(),
62            content: content.to_string(),
63            created: epoch_secs(),
64        };
65        if let Ok(mut store) = self.presets.write() {
66            store.insert(id.to_string(), preset.clone());
67        }
68        preset
69    }
70
71    /// Create a preset with auto-generated ID.
72    pub fn create(&self, name: &str, content: &str) -> PromptPreset {
73        let seq = self.counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
74        let id = format!("preset-{seq}");
75        self.save(&id, name, content)
76    }
77
78    /// Get a preset by ID.
79    #[must_use]
80    pub fn get(&self, id: &str) -> Option<PromptPreset> {
81        self.presets.read().ok()?.get(id).cloned()
82    }
83
84    /// List all presets.
85    #[must_use]
86    pub fn list(&self) -> Vec<PromptPreset> {
87        self.presets
88            .read()
89            .map(|s| {
90                let mut v: Vec<_> = s.values().cloned().collect();
91                v.sort_by(|a, b| a.id.cmp(&b.id));
92                v
93            })
94            .unwrap_or_default()
95    }
96
97    /// Delete a preset by ID. Returns true if it existed.
98    pub fn delete(&self, id: &str) -> bool {
99        self.presets.write().map(|mut s| s.remove(id).is_some()).unwrap_or(false)
100    }
101
102    /// Expand `@preset:id` references in a message content string.
103    /// Returns the expanded content, or the original if no preset found.
104    #[must_use]
105    pub fn expand(&self, content: &str) -> String {
106        if let Some(preset_id) = content.strip_prefix("@preset:") {
107            let preset_id = preset_id.trim();
108            if let Some(preset) = self.get(preset_id) {
109                return preset.content;
110            }
111        }
112        content.to_string()
113    }
114}
115
116impl Default for PromptStore {
117    fn default() -> Self {
118        Self::new()
119    }
120}
121
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// Yields 0 if the system clock reports a time before the epoch.
fn epoch_secs() -> u64 {
    let now = std::time::SystemTime::now();
    match now.duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}