// alith_client/workflows/reason/mod.rs

1pub mod decision;
2pub mod one_round;
3
4use crate::{
5    components::{cascade::CascadeFlow, instruct_prompt::InstructPrompt},
6    primitives::*,
7};
8use alith_interface::{llms::LLMBackend, requests::completion::CompletionRequest};
9use one_round::ReasonOneRound;
10use std::sync::Arc;
11
12pub trait ReasonTrait: PrimitiveTrait {
13    fn primitive_to_result_index(&self, content: &str) -> u32;
14
15    fn result_index_to_primitive(
16        &self,
17        result_index: Option<u32>,
18    ) -> crate::Result<Option<Self::PrimitiveResult>>;
19
20    fn parse_reason_result(
21        &self,
22        reason_result: &ReasonResult,
23    ) -> crate::Result<Option<Self::PrimitiveResult>> {
24        if let Some(result_index) = reason_result.result_index {
25            self.result_index_to_primitive(Some(result_index))
26        } else {
27            Ok(None)
28        }
29    }
30}
31
32pub struct ReasonWorkflowBuilder {
33    pub base_req: CompletionRequest,
34}
35
36impl ReasonWorkflowBuilder {
37    pub fn new(backend: Arc<LLMBackend>) -> Self {
38        Self {
39            base_req: CompletionRequest::new(backend),
40        }
41    }
42
43    fn build<P: PrimitiveTrait>(self) -> ReasonOneRound<P> {
44        ReasonOneRound {
45            primitive: P::default(),
46            base_req: self.base_req,
47            reasoning_sentences: 3,
48            conclusion_sentences: 2,
49            result_can_be_none: false,
50            instruct_prompt: InstructPrompt::default(),
51        }
52    }
53}
54
/// Generates one public, consuming method on `ReasonWorkflowBuilder` for every
/// `method_name => PrimitiveType` pair.
///
/// Each generated method forwards to the builder's private `build` with the given
/// primitive type and returns the resulting `ReasonOneRound<PrimitiveType>`.
macro_rules! reason_workflow_primitive_impl {
    ($($method:ident => $primitive:ty),*) => {
        impl ReasonWorkflowBuilder {
            $(
                pub fn $method(self) -> ReasonOneRound<$primitive> {
                    self.build::<$primitive>()
                }
            )*
        }
    };
}
66
67reason_workflow_primitive_impl! {
68    boolean => BooleanPrimitive,
69    integer => IntegerPrimitive,
70    exact_string => ExactStringPrimitive
71}
72
73#[derive(Clone)]
74pub struct ReasonResult {
75    pub primitive_result: Option<String>,
76    pub duration: std::time::Duration,
77    pub workflow: CascadeFlow,
78    pub result_index: Option<u32>,
79    pub temperature: f32,
80}
81
82impl ReasonResult {
83    fn new<P: PrimitiveTrait + ReasonTrait>(
84        flow: CascadeFlow,
85        primitive: &P,
86        base_req: &CompletionRequest,
87    ) -> crate::Result<Self> {
88        let primitive_result = flow.primitive_result();
89        let result_index = primitive_result
90            .as_ref()
91            .map(|primitive_result| primitive.primitive_to_result_index(primitive_result));
92        Ok(ReasonResult {
93            primitive_result,
94            duration: flow.duration,
95            workflow: flow,
96            result_index,
97            temperature: base_req.config.temperature,
98        })
99    }
100}
101
102impl std::fmt::Display for ReasonResult {
103    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104        writeln!(f, "{}", self.workflow)?;
105        writeln!(
106            f,
107            "{}Reason duration\x1b[0m: {:?}",
108            SETTINGS_GRADIENT[2], self.duration
109        )?;
110        writeln!(
111            f,
112            "{}Reason temperature\x1b[0m: {:?}",
113            SETTINGS_GRADIENT[1], self.temperature
114        )?;
115        if let Some(primitive_result) = &self.primitive_result {
116            writeln!(
117                f,
118                "{}Reason primitive_result\x1b[0m: {}",
119                SETTINGS_GRADIENT[0], primitive_result
120            )?;
121        } else {
122            writeln!(
123                f,
124                "{}Reason primitive_result\x1b[0m: None",
125                SETTINGS_GRADIENT[0],
126            )?;
127        };
128        Ok(())
129    }
130}
131
/// ANSI 24-bit foreground colour escape sequences forming a ten-step gradient from
/// green (`92;244;37`) to blue (`0;142;250`).
///
/// Every entry is a string literal, so the table is a plain const-initialised array:
/// no lazy initialisation, heap allocation or synchronisation is needed to read it.
static SETTINGS_GRADIENT: [&str; 10] = [
    "\x1B[38;2;92;244;37m",
    "\x1B[38;2;0;239;98m",
    "\x1B[38;2;0;225;149m",
    "\x1B[38;2;0;212;178m",
    "\x1B[38;2;0;201;196m",
    "\x1B[38;2;0;190;207m",
    "\x1B[38;2;0;180;215m",
    "\x1B[38;2;0;170;222m",
    "\x1B[38;2;0;159;235m",
    "\x1B[38;2;0;142;250m",
];