alith_client/workflows/basic_primitive.rs

1use crate::{
2    components::{
3        InstructPromptTrait,
4        cascade::{CascadeFlow, step::StepConfig},
5        instruct_prompt::InstructPrompt,
6    },
7    primitives::*,
8};
9use alith_interface::{
10    llms::LLMBackend,
11    requests::{
12        completion::CompletionRequest,
13        req_components::{RequestConfig, RequestConfigTrait},
14    },
15};
16use std::sync::Arc;
17
18pub struct BasicPrimitiveWorkflow<P> {
19    pub primitive: P,
20    pub base_req: CompletionRequest,
21    pub result_can_be_none: bool,
22    pub instruct_prompt: InstructPrompt,
23}
24
25impl<P: PrimitiveTrait> BasicPrimitiveWorkflow<P> {
26    pub fn new(backend: Arc<LLMBackend>) -> Self {
27        Self {
28            primitive: P::default(),
29            base_req: CompletionRequest::new(backend),
30            result_can_be_none: false,
31            instruct_prompt: InstructPrompt::default(),
32        }
33    }
34
35    pub async fn return_primitive(&mut self) -> crate::Result<P::PrimitiveResult> {
36        self.result_can_be_none = false;
37        let res = self.return_result().await?;
38        if let Some(primitive_result) = res.primitive_result {
39            Ok(self.primitive.parse_to_primitive(&primitive_result)?)
40        } else {
41            Err(anyhow::format_err!("No result returned."))
42        }
43    }
44
45    pub async fn return_optional_primitive(&mut self) -> crate::Result<Option<P::PrimitiveResult>> {
46        self.result_can_be_none = true;
47        let res = self.return_result().await?;
48        if let Some(primitive_result) = res.primitive_result {
49            Ok(Some(self.primitive.parse_to_primitive(&primitive_result)?))
50        } else {
51            Ok(None)
52        }
53    }
54
55    pub async fn return_result(&mut self) -> crate::Result<BasicPrimitiveResult> {
56        self.result_can_be_none = false;
57        let mut flow = self.basic_primitive()?;
58        flow.run_all_rounds(&mut self.base_req).await?;
59        BasicPrimitiveResult::new(flow)
60    }
61
62    pub async fn return_optional_result(&mut self) -> crate::Result<BasicPrimitiveResult> {
63        self.result_can_be_none = true;
64        let mut flow = self.basic_primitive()?;
65        flow.run_all_rounds(&mut self.base_req).await?;
66        BasicPrimitiveResult::new(flow)
67    }
68
69    fn basic_primitive(&mut self) -> crate::Result<CascadeFlow> {
70        let mut flow = CascadeFlow::new("BasicPrimitive");
71        let task = self.instruct_prompt.build_instruct_prompt(false)?;
72
73        let step_config = StepConfig {
74            step_prefix: Some(format!(
75                "Generating {}:\n",
76                self.primitive.solution_description(self.result_can_be_none),
77            )),
78            stop_word_no_result: self
79                .primitive
80                .stop_word_result_is_none(self.result_can_be_none),
81            grammar: self.primitive.grammar(),
82            ..StepConfig::default()
83        };
84
85        flow.new_round(task).add_inference_step(&step_config);
86
87        Ok(flow)
88    }
89}
90
91impl<P: PrimitiveTrait> RequestConfigTrait for BasicPrimitiveWorkflow<P> {
92    fn config(&mut self) -> &mut RequestConfig {
93        &mut self.base_req.config
94    }
95
96    fn reset_request(&mut self) {
97        self.instruct_prompt.reset_instruct_prompt();
98        self.base_req.reset_completion_request();
99    }
100}
101
102impl<P: PrimitiveTrait> InstructPromptTrait for BasicPrimitiveWorkflow<P> {
103    fn instruct_prompt_mut(&mut self) -> &mut InstructPrompt {
104        &mut self.instruct_prompt
105    }
106}
107
108pub struct BasicPrimitiveWorkflowBuilder {
109    pub base_req: CompletionRequest,
110}
111
112impl BasicPrimitiveWorkflowBuilder {
113    pub fn new(backend: Arc<LLMBackend>) -> Self {
114        Self {
115            base_req: CompletionRequest::new(backend),
116        }
117    }
118
119    fn build<P: PrimitiveTrait>(self) -> BasicPrimitiveWorkflow<P> {
120        BasicPrimitiveWorkflow {
121            primitive: P::default(),
122            base_req: self.base_req,
123            result_can_be_none: false,
124            instruct_prompt: InstructPrompt::default(),
125        }
126    }
127}
128
/// Generates one `BasicPrimitiveWorkflowBuilder` method per
/// `method_name => PrimitiveType` pair. Each method finishes the builder into a
/// `BasicPrimitiveWorkflow<PrimitiveType>` via `build`.
///
/// A trailing comma after the last pair is accepted.
macro_rules! basic_primitive_workflow_primitive_impl {
    ($($name:ident => $type:ty),* $(,)?) => {
        impl BasicPrimitiveWorkflowBuilder {
            $(
                #[doc = concat!(
                    "Finishes the builder as a workflow for `",
                    stringify!($type),
                    "`."
                )]
                pub fn $name(self) -> BasicPrimitiveWorkflow<$type> {
                    self.build()
                }
            )*
        }
    };
}
140
141basic_primitive_workflow_primitive_impl! {
142    boolean => BooleanPrimitive,
143    integer => IntegerPrimitive,
144    sentences => SentencesPrimitive,
145    words => WordsPrimitive,
146    exact_string => ExactStringPrimitive,
147    text_list => TextListPrimitive
148}
149
150#[derive(Clone)]
151pub struct BasicPrimitiveResult {
152    pub primitive_result: Option<String>,
153    pub duration: std::time::Duration,
154    pub workflow: CascadeFlow,
155}
156
157impl BasicPrimitiveResult {
158    pub fn new(flow: CascadeFlow) -> crate::Result<Self> {
159        let reason_result = BasicPrimitiveResult {
160            primitive_result: flow.primitive_result(),
161            duration: flow.duration,
162            workflow: flow,
163        };
164        Ok(reason_result)
165    }
166}