// alith_interface/requests/stop_sequence.rs

/// A single stop sequence together with the outcome it signals.
///
/// Both variants carry the literal stop text; the variant records how a
/// response that stopped on that text should be interpreted.
#[derive(Clone, Debug)]
pub enum StoppingSequence {
    /// Stop text registered via `StopSequences::set_stop_word_done`.
    InferenceDone(String),
    /// Stop text registered via `StopSequences::set_stop_word_no_result`.
    NoResult(String),
}

7impl PartialEq for StoppingSequence {
8 fn eq(&self, other: &Self) -> bool {
9 std::mem::discriminant(self) == std::mem::discriminant(other)
10 }
11}
12
13impl StoppingSequence {
14 pub fn as_str(&self) -> &str {
15 match self {
16 StoppingSequence::InferenceDone(s) => s,
17 StoppingSequence::NoResult(s) => s,
18 }
19 }
20}
21
22#[derive(Default, Debug, Clone)]
23pub struct StopSequences {
24 pub sequences: Vec<StoppingSequence>,
25 pub required: bool,
26}
27
28impl StopSequences {
29 pub fn new() -> Self {
30 Self {
31 sequences: Vec::new(),
32 required: false,
33 }
34 }
35
36 pub fn to_vec(&self) -> Vec<String> {
37 self.sequences
38 .iter()
39 .map(|sw| sw.as_str().to_owned())
40 .collect()
41 }
42
43 pub fn parse_string_response<T: AsRef<str>>(
44 &self,
45 response_stop_word: T,
46 ) -> Option<StoppingSequence> {
47 for stop_word in &self.sequences {
48 if response_stop_word.as_ref() == stop_word.as_str() {
49 return Some(stop_word.clone());
50 }
51 }
52 None
53 }
54
55 pub fn parse_option_response<T: AsRef<str>>(
56 &self,
57 response_stop_word: &Option<T>,
58 ) -> Option<StoppingSequence> {
59 match response_stop_word {
60 Some(stop_word) => self.parse_string_response(stop_word),
61 None => None,
62 }
63 }
64
65 pub fn error_on_required(&self) -> String {
66 format!(
67 "One of the sequences: {} is required, but response stopping_word is None.",
68 self.sequences
69 .iter()
70 .map(|sw| sw.as_str())
71 .collect::<Vec<&str>>()
72 .join(", ")
73 )
74 }
75
76 pub fn add_stop_word<T: AsRef<str>>(&self, stop_word: T) -> bool {
77 self.sequences.is_empty()
78 || !self
79 .sequences
80 .iter()
81 .any(|s| s.as_str() == stop_word.as_ref())
82 }
83
84 pub fn set_stop_word_done<T: AsRef<str>>(&mut self, stop_word: T) -> &mut Self {
85 if self.add_stop_word(&stop_word) {
86 self.sequences.push(StoppingSequence::InferenceDone(
87 stop_word.as_ref().to_owned(),
88 ));
89 }
90 self
91 }
92
93 pub fn set_stop_word_no_result<T: AsRef<str>>(&mut self, stop_word: T) -> &mut Self {
94 if self.add_stop_word(&stop_word) {
95 self.sequences
96 .push(StoppingSequence::NoResult(stop_word.as_ref().to_owned()));
97 }
98 self
99 }
100}