1use crate::context::RunContext;
7use async_trait::async_trait;
8use std::future::Future;
9use std::marker::PhantomData;
10
11#[async_trait]
16pub trait InstructionFn<Deps>: Send + Sync {
17 async fn generate(&self, ctx: &RunContext<Deps>) -> Option<String>;
21}
22
23#[async_trait]
28pub trait SystemPromptFn<Deps>: Send + Sync {
29 async fn generate(&self, ctx: &RunContext<Deps>) -> Option<String>;
33}
34
35pub struct AsyncInstructionFn<F, Deps, Fut>
41where
42 F: Fn(&RunContext<Deps>) -> Fut + Send + Sync,
43 Fut: Future<Output = Option<String>> + Send,
44{
45 func: F,
46 _phantom: PhantomData<fn(Deps) -> Fut>,
47}
48
49impl<F, Deps, Fut> AsyncInstructionFn<F, Deps, Fut>
50where
51 F: Fn(&RunContext<Deps>) -> Fut + Send + Sync,
52 Fut: Future<Output = Option<String>> + Send,
53{
54 pub fn new(func: F) -> Self {
56 Self {
57 func,
58 _phantom: PhantomData,
59 }
60 }
61}
62
63#[async_trait]
64impl<F, Deps, Fut> InstructionFn<Deps> for AsyncInstructionFn<F, Deps, Fut>
65where
66 F: Fn(&RunContext<Deps>) -> Fut + Send + Sync,
67 Fut: Future<Output = Option<String>> + Send,
68 Deps: Send + Sync,
69{
70 async fn generate(&self, ctx: &RunContext<Deps>) -> Option<String> {
71 (self.func)(ctx).await
72 }
73}
74
75pub struct AsyncSystemPromptFn<F, Deps, Fut>
77where
78 F: Fn(&RunContext<Deps>) -> Fut + Send + Sync,
79 Fut: Future<Output = Option<String>> + Send,
80{
81 func: F,
82 _phantom: PhantomData<fn(Deps) -> Fut>,
83}
84
85impl<F, Deps, Fut> AsyncSystemPromptFn<F, Deps, Fut>
86where
87 F: Fn(&RunContext<Deps>) -> Fut + Send + Sync,
88 Fut: Future<Output = Option<String>> + Send,
89{
90 pub fn new(func: F) -> Self {
92 Self {
93 func,
94 _phantom: PhantomData,
95 }
96 }
97}
98
99#[async_trait]
100impl<F, Deps, Fut> SystemPromptFn<Deps> for AsyncSystemPromptFn<F, Deps, Fut>
101where
102 F: Fn(&RunContext<Deps>) -> Fut + Send + Sync,
103 Fut: Future<Output = Option<String>> + Send,
104 Deps: Send + Sync,
105{
106 async fn generate(&self, ctx: &RunContext<Deps>) -> Option<String> {
107 (self.func)(ctx).await
108 }
109}
110
111pub struct SyncInstructionFn<F, Deps>
117where
118 F: Fn(&RunContext<Deps>) -> Option<String> + Send + Sync,
119{
120 func: F,
121 _phantom: PhantomData<Deps>,
122}
123
124impl<F, Deps> SyncInstructionFn<F, Deps>
125where
126 F: Fn(&RunContext<Deps>) -> Option<String> + Send + Sync,
127{
128 pub fn new(func: F) -> Self {
130 Self {
131 func,
132 _phantom: PhantomData,
133 }
134 }
135}
136
137#[async_trait]
138impl<F, Deps> InstructionFn<Deps> for SyncInstructionFn<F, Deps>
139where
140 F: Fn(&RunContext<Deps>) -> Option<String> + Send + Sync,
141 Deps: Send + Sync,
142{
143 async fn generate(&self, ctx: &RunContext<Deps>) -> Option<String> {
144 (self.func)(ctx)
145 }
146}
147
148pub struct SyncSystemPromptFn<F, Deps>
150where
151 F: Fn(&RunContext<Deps>) -> Option<String> + Send + Sync,
152{
153 func: F,
154 _phantom: PhantomData<Deps>,
155}
156
157impl<F, Deps> SyncSystemPromptFn<F, Deps>
158where
159 F: Fn(&RunContext<Deps>) -> Option<String> + Send + Sync,
160{
161 pub fn new(func: F) -> Self {
163 Self {
164 func,
165 _phantom: PhantomData,
166 }
167 }
168}
169
170#[async_trait]
171impl<F, Deps> SystemPromptFn<Deps> for SyncSystemPromptFn<F, Deps>
172where
173 F: Fn(&RunContext<Deps>) -> Option<String> + Send + Sync,
174 Deps: Send + Sync,
175{
176 async fn generate(&self, ctx: &RunContext<Deps>) -> Option<String> {
177 (self.func)(ctx)
178 }
179}
180
/// Instruction whose text is fixed when the value is constructed.
pub struct StaticInstruction {
    /// The stored instruction text.
    text: String,
}

impl StaticInstruction {
    /// Creates an instruction holding `text`, converted into an owned `String`.
    pub fn new(text: impl Into<String>) -> Self {
        let text: String = text.into();
        Self { text }
    }
}
196
197#[async_trait]
198impl<Deps: Send + Sync> InstructionFn<Deps> for StaticInstruction {
199 async fn generate(&self, _ctx: &RunContext<Deps>) -> Option<String> {
200 Some(self.text.clone())
201 }
202}
203
/// System prompt whose text is fixed when the value is constructed.
pub struct StaticSystemPrompt {
    /// The stored prompt text.
    text: String,
}

impl StaticSystemPrompt {
    /// Creates a prompt holding `text`, converted into an owned `String`.
    pub fn new(text: impl Into<String>) -> Self {
        let text: String = text.into();
        Self { text }
    }
}
215
216#[async_trait]
217impl<Deps: Send + Sync> SystemPromptFn<Deps> for StaticSystemPrompt {
218 async fn generate(&self, _ctx: &RunContext<Deps>) -> Option<String> {
219 Some(self.text.clone())
220 }
221}
222
223pub struct InstructionBuilder<Deps> {
229 parts: Vec<Box<dyn InstructionFn<Deps>>>,
230 separator: String,
231}
232
233impl<Deps: Send + Sync + 'static> InstructionBuilder<Deps> {
234 pub fn new() -> Self {
236 Self {
237 parts: Vec::new(),
238 separator: "\n\n".to_string(),
239 }
240 }
241
242 pub fn separator(mut self, sep: impl Into<String>) -> Self {
244 self.separator = sep.into();
245 self
246 }
247
248 #[allow(clippy::should_implement_trait)]
250 pub fn add(mut self, text: impl Into<String>) -> Self {
251 self.parts.push(Box::new(StaticInstruction::new(text)));
252 self
253 }
254
255 pub fn add_fn<F>(mut self, func: F) -> Self
257 where
258 F: Fn(&RunContext<Deps>) -> Option<String> + Send + Sync + 'static,
259 {
260 self.parts.push(Box::new(SyncInstructionFn::new(func)));
261 self
262 }
263
264 pub fn add_instruction(mut self, instruction: Box<dyn InstructionFn<Deps>>) -> Self {
266 self.parts.push(instruction);
267 self
268 }
269
270 pub fn build(self) -> CombinedInstruction<Deps> {
272 CombinedInstruction {
273 parts: self.parts,
274 separator: self.separator,
275 }
276 }
277}
278
279impl<Deps: Send + Sync + 'static> Default for InstructionBuilder<Deps> {
280 fn default() -> Self {
281 Self::new()
282 }
283}
284
285pub struct CombinedInstruction<Deps> {
287 parts: Vec<Box<dyn InstructionFn<Deps>>>,
288 separator: String,
289}
290
291#[async_trait]
292impl<Deps: Send + Sync> InstructionFn<Deps> for CombinedInstruction<Deps> {
293 async fn generate(&self, ctx: &RunContext<Deps>) -> Option<String> {
294 let mut results = Vec::new();
295
296 for part in &self.parts {
297 if let Some(text) = part.generate(ctx).await {
298 if !text.is_empty() {
299 results.push(text);
300 }
301 }
302 }
303
304 if results.is_empty() {
305 None
306 } else {
307 Some(results.join(&self.separator))
308 }
309 }
310}
311
/// Instruction made of a prefix followed by a formatted timestamp.
///
/// Defaults: format `"%Y-%m-%d %H:%M:%S UTC"`, prefix
/// `"Current date and time:"`.
pub struct DateTimeInstruction {
    /// Format string handed to the timestamp formatter.
    format: String,
    /// Text placed in front of the formatted timestamp.
    prefix: String,
}

impl DateTimeInstruction {
    /// Creates an instruction with the default format and prefix.
    pub fn new() -> Self {
        Self {
            prefix: String::from("Current date and time:"),
            format: String::from("%Y-%m-%d %H:%M:%S UTC"),
        }
    }

    /// Replaces the timestamp format string.
    pub fn format(self, fmt: impl Into<String>) -> Self {
        Self {
            format: fmt.into(),
            ..self
        }
    }

    /// Replaces the text placed before the timestamp.
    pub fn prefix(self, prefix: impl Into<String>) -> Self {
        Self {
            prefix: prefix.into(),
            ..self
        }
    }
}

impl Default for DateTimeInstruction {
    /// Equivalent to [`DateTimeInstruction::new`].
    fn default() -> Self {
        Self::new()
    }
}
349
350#[async_trait]
351impl<Deps: Send + Sync> InstructionFn<Deps> for DateTimeInstruction {
352 async fn generate(&self, ctx: &RunContext<Deps>) -> Option<String> {
353 let formatted = ctx.start_time.format(&self.format).to_string();
354 Some(format!("{} {}", self.prefix, formatted))
355 }
356}
357
358pub struct UserInfoInstruction<F, Deps>
360where
361 F: Fn(&Deps) -> Option<String> + Send + Sync,
362{
363 extractor: F,
364 _phantom: PhantomData<Deps>,
365}
366
367impl<F, Deps> UserInfoInstruction<F, Deps>
368where
369 F: Fn(&Deps) -> Option<String> + Send + Sync,
370{
371 pub fn new(extractor: F) -> Self {
373 Self {
374 extractor,
375 _phantom: PhantomData,
376 }
377 }
378}
379
380#[async_trait]
381impl<F, Deps> InstructionFn<Deps> for UserInfoInstruction<F, Deps>
382where
383 F: Fn(&Deps) -> Option<String> + Send + Sync,
384 Deps: Send + Sync,
385{
386 async fn generate(&self, ctx: &RunContext<Deps>) -> Option<String> {
387 (self.extractor)(&ctx.deps)
388 }
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use std::sync::Arc;

    /// Builds a unit-deps context with fixed identifiers; only the start time
    /// varies between calls.
    fn make_test_context() -> RunContext<()> {
        RunContext {
            deps: Arc::new(()),
            run_id: String::from("test-run"),
            start_time: Utc::now(),
            model_name: String::from("test-model"),
            model_settings: Default::default(),
            tool_name: None,
            tool_call_id: None,
            retry_count: 0,
            metadata: None,
        }
    }

    /// A static instruction returns exactly the text it was built with.
    #[tokio::test]
    async fn test_static_instruction() {
        let ctx = make_test_context();
        let instruction = StaticInstruction::new("Be helpful.");

        let generated = instruction.generate(&ctx).await;

        assert_eq!(generated.as_deref(), Some("Be helpful."));
    }

    /// A sync closure can read fields of the context it is handed.
    #[tokio::test]
    async fn test_sync_instruction_fn() {
        let ctx = make_test_context();
        let instruction =
            SyncInstructionFn::new(|ctx: &RunContext<()>| Some(format!("Run ID: {}", ctx.run_id)));

        let generated = instruction.generate(&ctx).await;

        assert_eq!(generated.as_deref(), Some("Run ID: test-run"));
    }

    /// Every added part shows up in the combined output.
    #[tokio::test]
    async fn test_instruction_builder() {
        let ctx = make_test_context();
        let combined = InstructionBuilder::<()>::new()
            .add("First instruction.")
            .add("Second instruction.")
            .build();

        let text = combined.generate(&ctx).await.unwrap();

        for expected in ["First instruction.", "Second instruction."] {
            assert!(text.contains(expected));
        }
    }

    /// The default date/time instruction carries the default prefix.
    #[tokio::test]
    async fn test_datetime_instruction() {
        let ctx = make_test_context();

        let text = DateTimeInstruction::new().generate(&ctx).await.unwrap();

        assert!(text.contains("Current date and time:"));
    }

    /// Parts that yield `None` or an empty string leave no trace in the output.
    #[tokio::test]
    async fn test_combined_instruction_skips_empty() {
        let ctx = make_test_context();
        let combined = InstructionBuilder::<()>::new()
            .add("Has content.")
            .add_fn(|_| None)
            .add("")
            .add("Also has content.")
            .build();

        let text = combined.generate(&ctx).await.unwrap();

        let segment_count = text.split("\n\n").count();
        assert_eq!(segment_count, 2);
    }
}