prompt_store/api/
runner.rs

1//! Fluent runners for executing single prompts or complex chains.
2
3use futures::future;
4use llm::{chain::MultiChainStepMode, LLMProvider};
5use regex::Regex;
6use std::collections::HashMap;
7use std::sync::{Arc, Mutex};
8
9use super::{
10    error::{RunError, StoreError},
11    llm_bridge::LLMBackendRef,
12    store::PromptStore,
13    RunOutput,
14};
15
/// Represents the source of a prompt for a chain step.
///
/// Resolution of a `Stored` prompt happens lazily, at execution time, via
/// `PromptStore::find_prompt`.
#[derive(Clone)]
enum PromptSource {
    /// Load the prompt from the store using its ID or title.
    Stored(String),
    /// Use a raw, in-memory string as the prompt template.
    Raw(String),
}
24
25// --- PromptRunner for single prompts ---
26
/// A fluent builder to configure and execute a single stored prompt.
pub struct PromptRunner<'a> {
    /// Store used to resolve `id_or_title` into prompt content.
    store: &'a PromptStore,
    /// Identifier (ID or title) of the prompt to run.
    id_or_title: &'a str,
    /// Variables substituted into `{{ name }}` placeholders before execution.
    vars: HashMap<String, String>,
    /// Optional LLM backend; when `None`, `run()` only renders the template.
    backend: Option<&'a dyn LLMProvider>,
}
34
35impl<'a> PromptRunner<'a> {
36    /// Creates a new `PromptRunner`.
37    pub(crate) fn new(store: &'a PromptStore, id_or_title: &'a str) -> Self {
38        Self {
39            store,
40            id_or_title,
41            vars: HashMap::new(),
42            backend: None,
43        }
44    }
45
46    /// Sets the variables for template substitution in the prompt.
47    pub fn vars(
48        mut self,
49        vars: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
50    ) -> Self {
51        self.vars = vars
52            .into_iter()
53            .map(|(k, v)| (k.into(), v.into()))
54            .collect();
55        self
56    }
57
58    /// Sets the LLM backend to execute the prompt with.
59    /// If not set, `run()` will only perform template substitution and return the result.
60    pub fn backend(mut self, llm: &'a dyn LLMProvider) -> Self {
61        self.backend = Some(llm);
62        self
63    }
64
65    /// Finds, decrypts, renders, and executes the prompt.
66    pub async fn run(self) -> Result<RunOutput, RunError> {
67        let pd = self.store.find_prompt(self.id_or_title)?;
68        let rendered = render_template(&pd.content, &self.vars);
69
70        let result = if let Some(llm) = self.backend {
71            use llm::chat::ChatMessage;
72            let req = ChatMessage::user().content(&rendered).build();
73            let resp = llm.chat(&[req]).await?;
74            resp.text().unwrap_or_default()
75        } else {
76            rendered
77        };
78
79        Ok(RunOutput::Prompt(result))
80    }
81}
82
83// --- ChainRunner for multi-step chains ---
84
/// Defines a single step in a chain.
///
/// The lifetime `'a` bounds the optional condition closure, which may borrow
/// from the caller's environment.
struct ChainStepDefinition<'a> {
    /// Context key under which this step's output is stored.
    pub output_key: String,
    /// Where the prompt template comes from (store lookup or raw string).
    pub source: PromptSource,
    /// Registry ID of the LLM provider to execute with; execution fails if
    /// this is still `None` when the step runs.
    pub provider_id: Option<String>,
    /// Execution mode passed through from `llm::chain`.
    pub mode: MultiChainStepMode,
    /// Optional predicate over the current context; when it returns `false`
    /// the step is skipped entirely.
    pub condition: Option<Box<dyn Fn(&HashMap<String, String>) -> bool + Send + Sync + 'a>>,
    /// Optional prompt to try when the primary source fails.
    pub fallback_source: Option<PromptSource>,
}
94
/// Represents a node in the execution graph of a chain.
enum ExecutionNode<'a> {
    /// A single, sequential step.
    Step(ChainStepDefinition<'a>),
    /// A group of steps to be executed in parallel; all steps share the same
    /// mutable context while running.
    Parallel(Vec<ChainStepDefinition<'a>>),
}
102
/// A builder for defining a group of parallel steps.
///
/// Obtained through `ChainRunner::parallel`; the accumulated steps become one
/// parallel node in the chain's execution graph.
pub struct ParallelGroupBuilder<'a> {
    /// Steps accumulated so far, in insertion order.
    steps: Vec<ChainStepDefinition<'a>>,
}
107
108impl<'a> ParallelGroupBuilder<'a> {
109    fn new() -> Self {
110        Self { steps: Vec::new() }
111    }
112
113    /// Adds a step from the store to the parallel group.
114    pub fn step(mut self, output_key: &str, prompt_id_or_title: &str) -> Self {
115        self.steps.push(ChainStepDefinition {
116            output_key: output_key.to_string(),
117            source: PromptSource::Stored(prompt_id_or_title.to_string()),
118            provider_id: None,
119            mode: MultiChainStepMode::Completion,
120            condition: None,
121            fallback_source: None,
122        });
123        self
124    }
125
126    /// Adds a raw prompt step to the parallel group.
127    pub fn step_raw(mut self, output_key: &str, prompt_content: &str) -> Self {
128        self.steps.push(ChainStepDefinition {
129            output_key: output_key.to_string(),
130            source: PromptSource::Raw(prompt_content.to_string()),
131            provider_id: None,
132            mode: MultiChainStepMode::Completion,
133            condition: None,
134            fallback_source: None,
135        });
136        self
137    }
138
139    /// Sets the provider for the last added step in the parallel group.
140    pub fn with_provider(mut self, provider_id: &str) -> Self {
141        if let Some(last_step) = self.steps.last_mut() {
142            last_step.provider_id = Some(provider_id.to_string());
143        }
144        self
145    }
146}
147
/// A fluent builder to define and execute a multi-step prompt chain.
pub struct ChainRunner<'a> {
    /// Store used to resolve `PromptSource::Stored` references.
    store: &'a PromptStore,
    /// Backend reference; `run()` requires this to be the `Registry` variant.
    backend: LLMBackendRef<'a>,
    /// Ordered execution graph of sequential steps and parallel groups.
    nodes: Vec<ExecutionNode<'a>>,
    /// Initial variables seeding the shared chain context.
    vars: HashMap<String, String>,
}
155
156impl<'a> ChainRunner<'a> {
157    /// Creates a new `ChainRunner`.
158    pub(crate) fn new(store: &'a PromptStore, backend: LLMBackendRef<'a>) -> Self {
159        Self {
160            store,
161            backend,
162            nodes: Vec::new(),
163            vars: HashMap::new(),
164        }
165    }
166
167    /// Adds a sequential step from the store.
168    pub fn step(mut self, output_key: &str, prompt_id_or_title: &str) -> Self {
169        self.nodes.push(ExecutionNode::Step(ChainStepDefinition {
170            output_key: output_key.to_string(),
171            source: PromptSource::Stored(prompt_id_or_title.to_string()),
172            provider_id: None,
173            mode: MultiChainStepMode::Completion,
174            condition: None,
175            fallback_source: None,
176        }));
177        self
178    }
179
180    /// Adds a sequential step with a raw prompt.
181    pub fn step_raw(mut self, output_key: &str, prompt_content: &str) -> Self {
182        self.nodes.push(ExecutionNode::Step(ChainStepDefinition {
183            output_key: output_key.to_string(),
184            source: PromptSource::Raw(prompt_content.to_string()),
185            provider_id: None,
186            mode: MultiChainStepMode::Completion,
187            condition: None,
188            fallback_source: None,
189        }));
190        self
191    }
192
193    /// Adds a conditional step from the store. It runs only if the condition is met.
194    pub fn step_if<F>(mut self, output_key: &str, prompt_id_or_title: &str, condition: F) -> Self
195    where
196        F: Fn(&HashMap<String, String>) -> bool + Send + Sync + 'a,
197    {
198        self.nodes.push(ExecutionNode::Step(ChainStepDefinition {
199            output_key: output_key.to_string(),
200            source: PromptSource::Stored(prompt_id_or_title.to_string()),
201            provider_id: None,
202            mode: MultiChainStepMode::Completion,
203            condition: Some(Box::new(condition)),
204            fallback_source: None,
205        }));
206        self
207    }
208
209    /// Adds a group of steps that will be executed in parallel.
210    pub fn parallel<F>(mut self, build_group: F) -> Self
211    where
212        F: for<'b> FnOnce(ParallelGroupBuilder<'b>) -> ParallelGroupBuilder<'b>,
213    {
214        let group_builder = ParallelGroupBuilder::new();
215        let finished_group = build_group(group_builder);
216        self.nodes
217            .push(ExecutionNode::Parallel(finished_group.steps));
218        self
219    }
220
221    /// Sets a fallback prompt from the store for the last added step.
222    /// This is executed if the primary prompt execution fails.
223    pub fn on_error_stored(mut self, fallback_id_or_title: &str) -> Self {
224        if let Some(node) = self.nodes.last_mut() {
225            if let ExecutionNode::Step(step_def) = node {
226                step_def.fallback_source =
227                    Some(PromptSource::Stored(fallback_id_or_title.to_string()));
228            }
229        }
230        self
231    }
232
233    /// Sets a raw fallback prompt for the last added step.
234    pub fn on_error_raw(mut self, fallback_content: &str) -> Self {
235        if let Some(node) = self.nodes.last_mut() {
236            if let ExecutionNode::Step(step_def) = node {
237                step_def.fallback_source = Some(PromptSource::Raw(fallback_content.to_string()));
238            }
239        }
240        self
241    }
242
243    /// Specifies the provider for the last added step or all steps in the last parallel group.
244    pub fn with_provider(mut self, provider_id: &str) -> Self {
245        if let Some(node) = self.nodes.last_mut() {
246            match node {
247                ExecutionNode::Step(step) => {
248                    step.provider_id = Some(provider_id.to_string());
249                }
250                ExecutionNode::Parallel(steps) => {
251                    for step in steps {
252                        if step.provider_id.is_none() {
253                            step.provider_id = Some(provider_id.to_string());
254                        }
255                    }
256                }
257            }
258        }
259        self
260    }
261
262    /// Sets the execution mode for the last added step.
263    pub fn with_mode(mut self, mode: MultiChainStepMode) -> Self {
264        if let Some(ExecutionNode::Step(step)) = self.nodes.last_mut() {
265            step.mode = mode;
266        }
267        self
268    }
269
270    /// Sets initial variables for the chain.
271    pub fn vars(
272        mut self,
273        vars: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
274    ) -> Self {
275        self.vars = vars
276            .into_iter()
277            .map(|(k, v)| (k.into(), v.into()))
278            .collect();
279        self
280    }
281
282    /// Executes the chain.
283    pub async fn run(self) -> Result<RunOutput, RunError> {
284        let reg = match self.backend {
285            LLMBackendRef::Registry(reg) => reg,
286            _ => {
287                return Err(StoreError::Configuration(
288                    "ChainRunner requires a LLMRegistry".to_string(),
289                )
290                .into())
291            }
292        };
293
294        let context = Arc::new(Mutex::new(self.vars.clone()));
295
296        for node in &self.nodes {
297            match node {
298                ExecutionNode::Step(step_def) => {
299                    self.execute_step(step_def, Arc::clone(&context), reg)
300                        .await?;
301                }
302                ExecutionNode::Parallel(steps) => {
303                    let tasks = steps
304                        .iter()
305                        .map(|step| {
306                            let context_clone = Arc::clone(&context);
307                            self.execute_step(step, context_clone, reg)
308                        })
309                        .collect::<Vec<_>>();
310
311                    future::try_join_all(tasks).await?;
312                }
313            }
314        }
315
316        let final_context = Arc::try_unwrap(context).ok().unwrap().into_inner().unwrap();
317        Ok(RunOutput::Chain(final_context))
318    }
319
320    async fn execute_step(
321        &self,
322        step_def: &ChainStepDefinition<'a>,
323        context: Arc<Mutex<HashMap<String, String>>>,
324        reg: &'a llm::chain::LLMRegistry,
325    ) -> Result<(), RunError> {
326        let should_run = {
327            let ctx = context.lock().unwrap();
328            step_def.condition.as_ref().map_or(true, |cond| cond(&ctx))
329        };
330        if !should_run {
331            return Ok(());
332        }
333
334        let result = self
335            .try_execute_source(&step_def.source, &context, step_def, reg)
336            .await;
337
338        let final_output = match (result, &step_def.fallback_source) {
339            (Ok(output), _) => Ok(output),
340            (Err(_), Some(fallback)) => {
341                self.try_execute_source(fallback, &context, step_def, reg)
342                    .await
343            }
344            (Err(e), None) => Err(e),
345        }?;
346
347        let mut ctx = context.lock().unwrap();
348        ctx.insert(step_def.output_key.clone(), final_output);
349        Ok(())
350    }
351
352    async fn try_execute_source(
353        &self,
354        source: &PromptSource,
355        context: &Arc<Mutex<HashMap<String, String>>>,
356        step_def: &ChainStepDefinition<'a>,
357        reg: &'a llm::chain::LLMRegistry,
358    ) -> Result<String, RunError> {
359        let provider_id = step_def.provider_id.as_deref().ok_or_else(|| {
360            StoreError::Configuration(format!(
361                "Step '{}' is missing a provider ID.",
362                step_def.output_key
363            ))
364        })?;
365        let provider = reg.get(provider_id).ok_or_else(|| {
366            StoreError::Configuration(format!("Provider '{}' not found in registry", provider_id))
367        })?;
368
369        let prompt_content = match source {
370            PromptSource::Stored(id) => self.store.find_prompt(id)?.content,
371            PromptSource::Raw(content) => content.clone(),
372        };
373
374        let rendered = {
375            let ctx = context.lock().unwrap();
376            render_template(&prompt_content, &ctx)
377        };
378
379        use llm::chat::ChatMessage;
380        let req = ChatMessage::user().content(&rendered).build();
381        let resp = provider.chat(&[req]).await?;
382        Ok(resp.text().unwrap_or_default())
383    }
384}
385
386/// Renders a template string with the given variables.
387fn render_template(template: &str, vars: &HashMap<String, String>) -> String {
388    let re = Regex::new(r"\{\{\s*(\w+)\s*\}\}").unwrap();
389    re.replace_all(template, |caps: &regex::Captures| {
390        let key = &caps[1];
391        vars.get(key).map(|s| s.as_str()).unwrap_or("").to_string()
392    })
393    .into_owned()
394}