1use futures::future;
4use llm::{chain::MultiChainStepMode, LLMProvider};
5use regex::Regex;
6use std::collections::HashMap;
7use std::sync::{Arc, Mutex};
8
9use super::{
10 error::{RunError, StoreError},
11 llm_bridge::LLMBackendRef,
12 store::PromptStore,
13 RunOutput,
14};
15
/// Where a chain step's prompt text comes from.
#[derive(Clone)]
enum PromptSource {
    /// Literal prompt text supplied directly by the caller.
    Raw(String),
    /// Identifier or title of a prompt to be looked up in the `PromptStore`.
    Stored(String),
}
24
25pub struct PromptRunner<'a> {
29 store: &'a PromptStore,
30 id_or_title: &'a str,
31 vars: HashMap<String, String>,
32 backend: Option<&'a dyn LLMProvider>,
33}
34
35impl<'a> PromptRunner<'a> {
36 pub(crate) fn new(store: &'a PromptStore, id_or_title: &'a str) -> Self {
38 Self {
39 store,
40 id_or_title,
41 vars: HashMap::new(),
42 backend: None,
43 }
44 }
45
46 pub fn vars(
48 mut self,
49 vars: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
50 ) -> Self {
51 self.vars = vars
52 .into_iter()
53 .map(|(k, v)| (k.into(), v.into()))
54 .collect();
55 self
56 }
57
58 pub fn backend(mut self, llm: &'a dyn LLMProvider) -> Self {
61 self.backend = Some(llm);
62 self
63 }
64
65 pub async fn run(self) -> Result<RunOutput, RunError> {
67 let pd = self.store.find_prompt(self.id_or_title)?;
68 let rendered = render_template(&pd.content, &self.vars);
69
70 let result = if let Some(llm) = self.backend {
71 use llm::chat::ChatMessage;
72 let req = ChatMessage::user().content(&rendered).build();
73 let resp = llm.chat(&[req]).await?;
74 resp.text().unwrap_or_default()
75 } else {
76 rendered
77 };
78
79 Ok(RunOutput::Prompt(result))
80 }
81}
82
83struct ChainStepDefinition<'a> {
87 pub output_key: String,
88 pub source: PromptSource,
89 pub provider_id: Option<String>,
90 pub mode: MultiChainStepMode,
91 pub condition: Option<Box<dyn Fn(&HashMap<String, String>) -> bool + Send + Sync + 'a>>,
92 pub fallback_source: Option<PromptSource>,
93}
94
95enum ExecutionNode<'a> {
97 Step(ChainStepDefinition<'a>),
99 Parallel(Vec<ChainStepDefinition<'a>>),
101}
102
103pub struct ParallelGroupBuilder<'a> {
105 steps: Vec<ChainStepDefinition<'a>>,
106}
107
108impl<'a> ParallelGroupBuilder<'a> {
109 fn new() -> Self {
110 Self { steps: Vec::new() }
111 }
112
113 pub fn step(mut self, output_key: &str, prompt_id_or_title: &str) -> Self {
115 self.steps.push(ChainStepDefinition {
116 output_key: output_key.to_string(),
117 source: PromptSource::Stored(prompt_id_or_title.to_string()),
118 provider_id: None,
119 mode: MultiChainStepMode::Completion,
120 condition: None,
121 fallback_source: None,
122 });
123 self
124 }
125
126 pub fn step_raw(mut self, output_key: &str, prompt_content: &str) -> Self {
128 self.steps.push(ChainStepDefinition {
129 output_key: output_key.to_string(),
130 source: PromptSource::Raw(prompt_content.to_string()),
131 provider_id: None,
132 mode: MultiChainStepMode::Completion,
133 condition: None,
134 fallback_source: None,
135 });
136 self
137 }
138
139 pub fn with_provider(mut self, provider_id: &str) -> Self {
141 if let Some(last_step) = self.steps.last_mut() {
142 last_step.provider_id = Some(provider_id.to_string());
143 }
144 self
145 }
146}
147
148pub struct ChainRunner<'a> {
150 store: &'a PromptStore,
151 backend: LLMBackendRef<'a>,
152 nodes: Vec<ExecutionNode<'a>>,
153 vars: HashMap<String, String>,
154}
155
156impl<'a> ChainRunner<'a> {
157 pub(crate) fn new(store: &'a PromptStore, backend: LLMBackendRef<'a>) -> Self {
159 Self {
160 store,
161 backend,
162 nodes: Vec::new(),
163 vars: HashMap::new(),
164 }
165 }
166
167 pub fn step(mut self, output_key: &str, prompt_id_or_title: &str) -> Self {
169 self.nodes.push(ExecutionNode::Step(ChainStepDefinition {
170 output_key: output_key.to_string(),
171 source: PromptSource::Stored(prompt_id_or_title.to_string()),
172 provider_id: None,
173 mode: MultiChainStepMode::Completion,
174 condition: None,
175 fallback_source: None,
176 }));
177 self
178 }
179
180 pub fn step_raw(mut self, output_key: &str, prompt_content: &str) -> Self {
182 self.nodes.push(ExecutionNode::Step(ChainStepDefinition {
183 output_key: output_key.to_string(),
184 source: PromptSource::Raw(prompt_content.to_string()),
185 provider_id: None,
186 mode: MultiChainStepMode::Completion,
187 condition: None,
188 fallback_source: None,
189 }));
190 self
191 }
192
193 pub fn step_if<F>(mut self, output_key: &str, prompt_id_or_title: &str, condition: F) -> Self
195 where
196 F: Fn(&HashMap<String, String>) -> bool + Send + Sync + 'a,
197 {
198 self.nodes.push(ExecutionNode::Step(ChainStepDefinition {
199 output_key: output_key.to_string(),
200 source: PromptSource::Stored(prompt_id_or_title.to_string()),
201 provider_id: None,
202 mode: MultiChainStepMode::Completion,
203 condition: Some(Box::new(condition)),
204 fallback_source: None,
205 }));
206 self
207 }
208
209 pub fn parallel<F>(mut self, build_group: F) -> Self
211 where
212 F: for<'b> FnOnce(ParallelGroupBuilder<'b>) -> ParallelGroupBuilder<'b>,
213 {
214 let group_builder = ParallelGroupBuilder::new();
215 let finished_group = build_group(group_builder);
216 self.nodes
217 .push(ExecutionNode::Parallel(finished_group.steps));
218 self
219 }
220
221 pub fn on_error_stored(mut self, fallback_id_or_title: &str) -> Self {
224 if let Some(node) = self.nodes.last_mut() {
225 if let ExecutionNode::Step(step_def) = node {
226 step_def.fallback_source =
227 Some(PromptSource::Stored(fallback_id_or_title.to_string()));
228 }
229 }
230 self
231 }
232
233 pub fn on_error_raw(mut self, fallback_content: &str) -> Self {
235 if let Some(node) = self.nodes.last_mut() {
236 if let ExecutionNode::Step(step_def) = node {
237 step_def.fallback_source = Some(PromptSource::Raw(fallback_content.to_string()));
238 }
239 }
240 self
241 }
242
243 pub fn with_provider(mut self, provider_id: &str) -> Self {
245 if let Some(node) = self.nodes.last_mut() {
246 match node {
247 ExecutionNode::Step(step) => {
248 step.provider_id = Some(provider_id.to_string());
249 }
250 ExecutionNode::Parallel(steps) => {
251 for step in steps {
252 if step.provider_id.is_none() {
253 step.provider_id = Some(provider_id.to_string());
254 }
255 }
256 }
257 }
258 }
259 self
260 }
261
262 pub fn with_mode(mut self, mode: MultiChainStepMode) -> Self {
264 if let Some(ExecutionNode::Step(step)) = self.nodes.last_mut() {
265 step.mode = mode;
266 }
267 self
268 }
269
270 pub fn vars(
272 mut self,
273 vars: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
274 ) -> Self {
275 self.vars = vars
276 .into_iter()
277 .map(|(k, v)| (k.into(), v.into()))
278 .collect();
279 self
280 }
281
282 pub async fn run(self) -> Result<RunOutput, RunError> {
284 let reg = match self.backend {
285 LLMBackendRef::Registry(reg) => reg,
286 _ => {
287 return Err(StoreError::Configuration(
288 "ChainRunner requires a LLMRegistry".to_string(),
289 )
290 .into())
291 }
292 };
293
294 let context = Arc::new(Mutex::new(self.vars.clone()));
295
296 for node in &self.nodes {
297 match node {
298 ExecutionNode::Step(step_def) => {
299 self.execute_step(step_def, Arc::clone(&context), reg)
300 .await?;
301 }
302 ExecutionNode::Parallel(steps) => {
303 let tasks = steps
304 .iter()
305 .map(|step| {
306 let context_clone = Arc::clone(&context);
307 self.execute_step(step, context_clone, reg)
308 })
309 .collect::<Vec<_>>();
310
311 future::try_join_all(tasks).await?;
312 }
313 }
314 }
315
316 let final_context = Arc::try_unwrap(context).ok().unwrap().into_inner().unwrap();
317 Ok(RunOutput::Chain(final_context))
318 }
319
320 async fn execute_step(
321 &self,
322 step_def: &ChainStepDefinition<'a>,
323 context: Arc<Mutex<HashMap<String, String>>>,
324 reg: &'a llm::chain::LLMRegistry,
325 ) -> Result<(), RunError> {
326 let should_run = {
327 let ctx = context.lock().unwrap();
328 step_def.condition.as_ref().map_or(true, |cond| cond(&ctx))
329 };
330 if !should_run {
331 return Ok(());
332 }
333
334 let result = self
335 .try_execute_source(&step_def.source, &context, step_def, reg)
336 .await;
337
338 let final_output = match (result, &step_def.fallback_source) {
339 (Ok(output), _) => Ok(output),
340 (Err(_), Some(fallback)) => {
341 self.try_execute_source(fallback, &context, step_def, reg)
342 .await
343 }
344 (Err(e), None) => Err(e),
345 }?;
346
347 let mut ctx = context.lock().unwrap();
348 ctx.insert(step_def.output_key.clone(), final_output);
349 Ok(())
350 }
351
352 async fn try_execute_source(
353 &self,
354 source: &PromptSource,
355 context: &Arc<Mutex<HashMap<String, String>>>,
356 step_def: &ChainStepDefinition<'a>,
357 reg: &'a llm::chain::LLMRegistry,
358 ) -> Result<String, RunError> {
359 let provider_id = step_def.provider_id.as_deref().ok_or_else(|| {
360 StoreError::Configuration(format!(
361 "Step '{}' is missing a provider ID.",
362 step_def.output_key
363 ))
364 })?;
365 let provider = reg.get(provider_id).ok_or_else(|| {
366 StoreError::Configuration(format!("Provider '{}' not found in registry", provider_id))
367 })?;
368
369 let prompt_content = match source {
370 PromptSource::Stored(id) => self.store.find_prompt(id)?.content,
371 PromptSource::Raw(content) => content.clone(),
372 };
373
374 let rendered = {
375 let ctx = context.lock().unwrap();
376 render_template(&prompt_content, &ctx)
377 };
378
379 use llm::chat::ChatMessage;
380 let req = ChatMessage::user().content(&rendered).build();
381 let resp = provider.chat(&[req]).await?;
382 Ok(resp.text().unwrap_or_default())
383 }
384}
385
386fn render_template(template: &str, vars: &HashMap<String, String>) -> String {
388 let re = Regex::new(r"\{\{\s*(\w+)\s*\}\}").unwrap();
389 re.replace_all(template, |caps: ®ex::Captures| {
390 let key = &caps[1];
391 vars.get(key).map(|s| s.as_str()).unwrap_or("").to_string()
392 })
393 .into_owned()
394}