kalosm_language/search/preprocessing/
hypothetical.rs1use kalosm_language_model::{CreateChatSession, Embedder, StructuredChatModel};
2use kalosm_sample::{IndexParser, LiteralParser, ParserExt, StopOn};
3
4use crate::{
5 prelude::{Document, Task},
6 search::Chunk,
7};
8
9use super::{ChunkStrategy, Chunker};
10
/// The default task description used by [`HypotheticalBuilder::build`] when no custom
/// description was supplied.
const TASK_DESCRIPTION: &str = concat!(
    "You generate hypothetical questions that may be answered by the given text. ",
    "The questions restate any information necessary to understand the question",
);
13
/// The default `(text, question)` example pairs used by [`HypotheticalBuilder::build`] when
/// no custom examples were supplied.
const EXAMPLES: [(&str, &str); 2] = [
    (
        "A content delivery network or a CDN optimizes the distribution of web content by strategically placing servers worldwide. This reduces latency, accelerates content delivery, and enhances the overall user experience.",
        "What role does a content delivery network play in web performance?",
    ),
    (
        "The Internet of Things or IoT connects everyday devices to the internet, enabling them to send and receive data. This connectivity enhances automation and allows for more efficient monitoring and control of various systems.",
        "What is the purpose of the Internet of Things?",
    ),
];
15
/// The words a generated question is allowed to start with.
///
/// The index chosen by the constraint parser is mapped back to one of these entries when
/// the question text is reassembled, so the order of this array matters.
const QUESTION_STARTERS: [&str; 9] = [
    "Who",
    "What",
    "When",
    "Where",
    "Why",
    "How",
    "Which",
    "Whom",
    "Whose",
];
19
/// The literal text every constrained model response starts with; it is also prepended to
/// the expected output of each example.
const PREFIX: &str = "Questions that are answered by the previous text: ";
21
22type Constraints = kalosm_sample::SequenceParser<
23 LiteralParser,
24 kalosm_sample::RepeatParser<
25 kalosm_sample::SequenceParser<IndexParser<LiteralParser>, StopOn<&'static str>>,
26 >,
27>;
28
29fn create_constraints() -> Constraints {
30 LiteralParser::new(PREFIX).then(
31 IndexParser::new(
32 QUESTION_STARTERS
33 .iter()
34 .copied()
35 .map(LiteralParser::new)
36 .collect::<Vec<_>>(),
37 )
38 .then(StopOn::new("?").filter_characters(
39 |c| matches!(c, ' ' | '?' | 'a'..='z' | 'A'..='Z' | '0'..='9' | ','),
40 ))
41 .repeat(1..=5),
42 )
43}
44
45pub struct HypotheticalBuilder<M: CreateChatSession> {
47 model: M,
48 task_description: Option<String>,
49 examples: Option<Vec<(String, String)>>,
50 chunking: Option<ChunkStrategy>,
51}
52
53impl<M: CreateChatSession> HypotheticalBuilder<M> {
54 pub fn with_chunking(mut self, chunking: ChunkStrategy) -> Self {
56 self.chunking = Some(chunking);
57 self
58 }
59
60 pub fn with_examples<S: Into<String>>(
62 mut self,
63 examples: impl IntoIterator<Item = (S, S)>,
64 ) -> Self {
65 self.examples = Some(
66 examples
67 .into_iter()
68 .map(|(a, b)| (a.into(), { PREFIX.to_string() + &b.into() }))
69 .collect::<Vec<_>>(),
70 );
71 self
72 }
73
74 pub fn with_task_description(mut self, task_description: String) -> Self {
76 self.task_description = Some(task_description);
77 self
78 }
79
80 pub fn build(self) -> Hypothetical<M> {
82 let task_description = self
83 .task_description
84 .unwrap_or_else(|| TASK_DESCRIPTION.to_string());
85 let examples = self.examples.unwrap_or_else(|| {
86 EXAMPLES
87 .iter()
88 .map(|(a, b)| (a.to_string(), { PREFIX.to_string() + b }))
89 .collect::<Vec<_>>()
90 });
91 let chunking = self.chunking;
92
93 let task = Task::new(self.model, task_description).with_examples(examples);
94
95 Hypothetical { chunking, task }
96 }
97}
98
99pub struct Hypothetical<M: CreateChatSession> {
101 chunking: Option<ChunkStrategy>,
102 task: Task<M>,
103}
104
105impl<M: CreateChatSession> Hypothetical<M> {
106 pub fn builder(model: M) -> HypotheticalBuilder<M> {
108 HypotheticalBuilder {
109 model,
110 task_description: None,
111 examples: None,
112 chunking: None,
113 }
114 }
115
116 pub async fn generate_question(&self, text: &str) -> Result<Vec<String>, M::Error>
118 where
119 M: StructuredChatModel<Constraints> + Send + Sync + Clone + Unpin + 'static,
120 M::ChatSession: Clone + Send + Sync + Unpin + 'static,
121 M::Error: Send + Sync + Unpin,
122 {
123 let questions = self
124 .task
125 .run(text)
126 .with_constraints(create_constraints())
127 .await?;
128 let documents = questions
129 .1
130 .into_iter()
131 .map(|((i, _), s)| QUESTION_STARTERS[i].to_string() + &s)
132 .collect::<Vec<_>>();
133
134 Ok(documents)
135 }
136}
137
138#[derive(Debug, thiserror::Error)]
140pub enum HypotheticalChunkerError<E1: Send + Sync + 'static, E2: Send + Sync + 'static> {
141 #[error("Text generation model error: {0}")]
143 TextModelError(#[from] E1),
144 #[error("Embedding model error: {0}")]
146 EmbeddingModelError(E2),
147}
148
149impl<M> Chunker for Hypothetical<M>
150where
151 M: StructuredChatModel<Constraints> + Send + Sync + Clone + Unpin + 'static,
152 M::ChatSession: Clone + Send + Sync + Unpin + 'static,
153 M::Error: Send + Sync + Unpin,
154{
155 type Error<E: Send + Sync + 'static> = HypotheticalChunkerError<M::Error, E>;
156
157 async fn chunk<E: Embedder + Send>(
158 &self,
159 document: &Document,
160 embedder: &E,
161 ) -> Result<Vec<Chunk>, Self::Error<E::Error>> {
162 let body = document.body();
163
164 #[allow(clippy::single_range_in_vec_init)]
165 let byte_chunks = self
166 .chunking
167 .map(|chunking| chunking.chunk_str(body))
168 .unwrap_or_else(|| vec![0..body.len()]);
169
170 if byte_chunks.is_empty() {
171 return Ok(vec![]);
172 }
173
174 let mut questions = Vec::new();
175 let mut questions_count = Vec::new();
176 for byte_chunk in &byte_chunks {
177 let text = &body[byte_chunk.clone()];
178 let mut chunk_questions = self.generate_question(text).await?;
179 questions.append(&mut chunk_questions);
180 questions_count.push(chunk_questions.len());
181 }
182 let embeddings = embedder
183 .embed_vec(questions)
184 .await
185 .map_err(HypotheticalChunkerError::EmbeddingModelError)?;
186
187 let mut chunks = Vec::with_capacity(embeddings.len());
188 let mut questions_count = questions_count.iter();
189 let mut byte_chunks = byte_chunks.into_iter();
190
191 let mut remaining_embeddings = *questions_count.next().unwrap();
192 let mut byte_chunk = byte_chunks.next().unwrap();
193
194 for embedding in embeddings {
195 while remaining_embeddings == 0 {
196 if let Some(&questions_count) = questions_count.next() {
197 remaining_embeddings = questions_count;
198 byte_chunk = byte_chunks.next().unwrap();
199 }
200 }
201 remaining_embeddings -= 1;
202 chunks.push(Chunk {
203 byte_range: byte_chunk.clone(),
204 embeddings: vec![embedding],
205 });
206 }
207
208 Ok(chunks)
209 }
210}