alith_interface/requests/stop_sequence.rs

/// A stop sequence: the literal stop text, tagged with the kind of stop it represents.
///
/// Both variants carry the same payload (the stop text itself); only the
/// variant differs.
#[derive(Clone, Debug)]
pub enum StoppingSequence {
    /// Stop text tagged as signalling that inference is done.
    InferenceDone(String),
    /// Stop text tagged as signalling that no result was produced.
    NoResult(String),
}
6
7impl PartialEq for StoppingSequence {
8    fn eq(&self, other: &Self) -> bool {
9        std::mem::discriminant(self) == std::mem::discriminant(other)
10    }
11}
12
13impl StoppingSequence {
14    pub fn as_str(&self) -> &str {
15        match self {
16            StoppingSequence::InferenceDone(s) => s,
17            StoppingSequence::NoResult(s) => s,
18        }
19    }
20}
21
22#[derive(Default, Debug, Clone)]
23pub struct StopSequences {
24    pub sequences: Vec<StoppingSequence>,
25    pub required: bool,
26}
27
28impl StopSequences {
29    pub fn new() -> Self {
30        Self {
31            sequences: Vec::new(),
32            required: false,
33        }
34    }
35
36    pub fn to_vec(&self) -> Vec<String> {
37        self.sequences
38            .iter()
39            .map(|sw| sw.as_str().to_owned())
40            .collect()
41    }
42
43    pub fn parse_string_response<T: AsRef<str>>(
44        &self,
45        response_stop_word: T,
46    ) -> Option<StoppingSequence> {
47        for stop_word in &self.sequences {
48            if response_stop_word.as_ref() == stop_word.as_str() {
49                return Some(stop_word.clone());
50            }
51        }
52        None
53    }
54
55    pub fn parse_option_response<T: AsRef<str>>(
56        &self,
57        response_stop_word: &Option<T>,
58    ) -> Option<StoppingSequence> {
59        match response_stop_word {
60            Some(stop_word) => self.parse_string_response(stop_word),
61            None => None,
62        }
63    }
64
65    pub fn error_on_required(&self) -> String {
66        format!(
67            "One of the sequences: {} is required, but response stopping_word is None.",
68            self.sequences
69                .iter()
70                .map(|sw| sw.as_str())
71                .collect::<Vec<&str>>()
72                .join(", ")
73        )
74    }
75
76    pub fn add_stop_word<T: AsRef<str>>(&self, stop_word: T) -> bool {
77        self.sequences.is_empty()
78            || !self
79                .sequences
80                .iter()
81                .any(|s| s.as_str() == stop_word.as_ref())
82    }
83
84    pub fn set_stop_word_done<T: AsRef<str>>(&mut self, stop_word: T) -> &mut Self {
85        if self.add_stop_word(&stop_word) {
86            self.sequences.push(StoppingSequence::InferenceDone(
87                stop_word.as_ref().to_owned(),
88            ));
89        }
90        self
91    }
92
93    pub fn set_stop_word_no_result<T: AsRef<str>>(&mut self, stop_word: T) -> &mut Self {
94        if self.add_stop_word(&stop_word) {
95            self.sequences
96                .push(StoppingSequence::NoResult(stop_word.as_ref().to_owned()));
97        }
98        self
99    }
100}