alith_client/workflows/
basic_primitive.rs1use crate::{
2 components::{
3 InstructPromptTrait,
4 cascade::{CascadeFlow, step::StepConfig},
5 instruct_prompt::InstructPrompt,
6 },
7 primitives::*,
8};
9use alith_interface::{
10 llms::LLMBackend,
11 requests::{
12 completion::CompletionRequest,
13 req_components::{RequestConfig, RequestConfigTrait},
14 },
15};
16use std::sync::Arc;
17
18pub struct BasicPrimitiveWorkflow<P> {
19 pub primitive: P,
20 pub base_req: CompletionRequest,
21 pub result_can_be_none: bool,
22 pub instruct_prompt: InstructPrompt,
23}
24
25impl<P: PrimitiveTrait> BasicPrimitiveWorkflow<P> {
26 pub fn new(backend: Arc<LLMBackend>) -> Self {
27 Self {
28 primitive: P::default(),
29 base_req: CompletionRequest::new(backend),
30 result_can_be_none: false,
31 instruct_prompt: InstructPrompt::default(),
32 }
33 }
34
35 pub async fn return_primitive(&mut self) -> crate::Result<P::PrimitiveResult> {
36 self.result_can_be_none = false;
37 let res = self.return_result().await?;
38 if let Some(primitive_result) = res.primitive_result {
39 Ok(self.primitive.parse_to_primitive(&primitive_result)?)
40 } else {
41 Err(anyhow::format_err!("No result returned."))
42 }
43 }
44
45 pub async fn return_optional_primitive(&mut self) -> crate::Result<Option<P::PrimitiveResult>> {
46 self.result_can_be_none = true;
47 let res = self.return_result().await?;
48 if let Some(primitive_result) = res.primitive_result {
49 Ok(Some(self.primitive.parse_to_primitive(&primitive_result)?))
50 } else {
51 Ok(None)
52 }
53 }
54
55 pub async fn return_result(&mut self) -> crate::Result<BasicPrimitiveResult> {
56 self.result_can_be_none = false;
57 let mut flow = self.basic_primitive()?;
58 flow.run_all_rounds(&mut self.base_req).await?;
59 BasicPrimitiveResult::new(flow)
60 }
61
62 pub async fn return_optional_result(&mut self) -> crate::Result<BasicPrimitiveResult> {
63 self.result_can_be_none = true;
64 let mut flow = self.basic_primitive()?;
65 flow.run_all_rounds(&mut self.base_req).await?;
66 BasicPrimitiveResult::new(flow)
67 }
68
69 fn basic_primitive(&mut self) -> crate::Result<CascadeFlow> {
70 let mut flow = CascadeFlow::new("BasicPrimitive");
71 let task = self.instruct_prompt.build_instruct_prompt(false)?;
72
73 let step_config = StepConfig {
74 step_prefix: Some(format!(
75 "Generating {}:\n",
76 self.primitive.solution_description(self.result_can_be_none),
77 )),
78 stop_word_no_result: self
79 .primitive
80 .stop_word_result_is_none(self.result_can_be_none),
81 grammar: self.primitive.grammar(),
82 ..StepConfig::default()
83 };
84
85 flow.new_round(task).add_inference_step(&step_config);
86
87 Ok(flow)
88 }
89}
90
91impl<P: PrimitiveTrait> RequestConfigTrait for BasicPrimitiveWorkflow<P> {
92 fn config(&mut self) -> &mut RequestConfig {
93 &mut self.base_req.config
94 }
95
96 fn reset_request(&mut self) {
97 self.instruct_prompt.reset_instruct_prompt();
98 self.base_req.reset_completion_request();
99 }
100}
101
102impl<P: PrimitiveTrait> InstructPromptTrait for BasicPrimitiveWorkflow<P> {
103 fn instruct_prompt_mut(&mut self) -> &mut InstructPrompt {
104 &mut self.instruct_prompt
105 }
106}
107
108pub struct BasicPrimitiveWorkflowBuilder {
109 pub base_req: CompletionRequest,
110}
111
112impl BasicPrimitiveWorkflowBuilder {
113 pub fn new(backend: Arc<LLMBackend>) -> Self {
114 Self {
115 base_req: CompletionRequest::new(backend),
116 }
117 }
118
119 fn build<P: PrimitiveTrait>(self) -> BasicPrimitiveWorkflow<P> {
120 BasicPrimitiveWorkflow {
121 primitive: P::default(),
122 base_req: self.base_req,
123 result_can_be_none: false,
124 instruct_prompt: InstructPrompt::default(),
125 }
126 }
127}
128
129macro_rules! basic_primitive_workflow_primitive_impl {
130 ($($name:ident => $type:ty),*) => {
131 impl BasicPrimitiveWorkflowBuilder {
132 $(
133 pub fn $name(self) -> BasicPrimitiveWorkflow<$type> {
134 self.build()
135 }
136 )*
137 }
138 }
139}
140
141basic_primitive_workflow_primitive_impl! {
142 boolean => BooleanPrimitive,
143 integer => IntegerPrimitive,
144 sentences => SentencesPrimitive,
145 words => WordsPrimitive,
146 exact_string => ExactStringPrimitive,
147 text_list => TextListPrimitive
148}
149
150#[derive(Clone)]
151pub struct BasicPrimitiveResult {
152 pub primitive_result: Option<String>,
153 pub duration: std::time::Duration,
154 pub workflow: CascadeFlow,
155}
156
157impl BasicPrimitiveResult {
158 pub fn new(flow: CascadeFlow) -> crate::Result<Self> {
159 let reason_result = BasicPrimitiveResult {
160 primitive_result: flow.primitive_result(),
161 duration: flow.duration,
162 workflow: flow,
163 };
164 Ok(reason_result)
165 }
166}