kalosm_language/search/preprocessing/
hypothetical.rs

1use kalosm_language_model::{CreateChatSession, Embedder, StructuredChatModel};
2use kalosm_sample::{IndexParser, LiteralParser, ParserExt, StopOn};
3
4use crate::{
5    prelude::{Document, Task},
6    search::Chunk,
7};
8
9use super::{ChunkStrategy, Chunker};
10
/// Task description used by [`HypotheticalBuilder::build`] when no custom description was set.
const TASK_DESCRIPTION: &str = "You generate hypothetical questions that may be answered by the given text. The questions restate any information necessary to understand the question";

/// Built-in `(text, question)` example pairs used by [`HypotheticalBuilder::build`] when no
/// custom examples were supplied. [`PREFIX`] is prepended to each question at build time.
const EXAMPLES: [(&str, &str); 2] = [
    (
        "A content delivery network or a CDN optimizes the distribution of web content by strategically placing servers worldwide. This reduces latency, accelerates content delivery, and enhances the overall user experience.",
        "What role does a content delivery network play in web performance?",
    ),
    (
        "The Internet of Things or IoT connects everyday devices to the internet, enabling them to send and receive data. This connectivity enhances automation and allows for more efficient monitoring and control of various systems.",
        "What is the purpose of the Internet of Things?",
    ),
];

/// Words a generated question may start with. The constrained parser reports the index of the
/// starter it matched, and that index is used to look the word up in this array again.
const QUESTION_STARTERS: [&str; 9] = [
    "Who",
    "What",
    "When",
    "Where",
    "Why",
    "How",
    "Which",
    "Whom",
    "Whose",
];

/// Literal text that precedes the list of questions, both in the example outputs and in the
/// constrained model output.
const PREFIX: &str = "Questions that are answered by the previous text: ";
21
22type Constraints = kalosm_sample::SequenceParser<
23    LiteralParser,
24    kalosm_sample::RepeatParser<
25        kalosm_sample::SequenceParser<IndexParser<LiteralParser>, StopOn<&'static str>>,
26    >,
27>;
28
29fn create_constraints() -> Constraints {
30    LiteralParser::new(PREFIX).then(
31        IndexParser::new(
32            QUESTION_STARTERS
33                .iter()
34                .copied()
35                .map(LiteralParser::new)
36                .collect::<Vec<_>>(),
37        )
38        .then(StopOn::new("?").filter_characters(
39            |c| matches!(c, ' ' | '?' | 'a'..='z' | 'A'..='Z' | '0'..='9' | ','),
40        ))
41        .repeat(1..=5),
42    )
43}
44
45/// A builder for a hypothetical chunker.
46pub struct HypotheticalBuilder<M: CreateChatSession> {
47    model: M,
48    task_description: Option<String>,
49    examples: Option<Vec<(String, String)>>,
50    chunking: Option<ChunkStrategy>,
51}
52
53impl<M: CreateChatSession> HypotheticalBuilder<M> {
54    /// Set the chunking strategy.
55    pub fn with_chunking(mut self, chunking: ChunkStrategy) -> Self {
56        self.chunking = Some(chunking);
57        self
58    }
59
60    /// Set the examples for this task. Each example should include the text and the questions that are answered by the text.
61    pub fn with_examples<S: Into<String>>(
62        mut self,
63        examples: impl IntoIterator<Item = (S, S)>,
64    ) -> Self {
65        self.examples = Some(
66            examples
67                .into_iter()
68                .map(|(a, b)| (a.into(), { PREFIX.to_string() + &b.into() }))
69                .collect::<Vec<_>>(),
70        );
71        self
72    }
73
74    /// Set the task description. The task description should describe a task of generating hypothetical questions that may be answered by the given text.
75    pub fn with_task_description(mut self, task_description: String) -> Self {
76        self.task_description = Some(task_description);
77        self
78    }
79
80    /// Build the hypothetical chunker.
81    pub fn build(self) -> Hypothetical<M> {
82        let task_description = self
83            .task_description
84            .unwrap_or_else(|| TASK_DESCRIPTION.to_string());
85        let examples = self.examples.unwrap_or_else(|| {
86            EXAMPLES
87                .iter()
88                .map(|(a, b)| (a.to_string(), { PREFIX.to_string() + b }))
89                .collect::<Vec<_>>()
90        });
91        let chunking = self.chunking;
92
93        let task = Task::new(self.model, task_description).with_examples(examples);
94
95        Hypothetical { chunking, task }
96    }
97}
98
99/// Generates questions for a document.
100pub struct Hypothetical<M: CreateChatSession> {
101    chunking: Option<ChunkStrategy>,
102    task: Task<M>,
103}
104
105impl<M: CreateChatSession> Hypothetical<M> {
106    /// Create a new hypothetical generator.
107    pub fn builder(model: M) -> HypotheticalBuilder<M> {
108        HypotheticalBuilder {
109            model,
110            task_description: None,
111            examples: None,
112            chunking: None,
113        }
114    }
115
116    /// Generate a list of hypothetical questions about the given text.
117    pub async fn generate_question(&self, text: &str) -> Result<Vec<String>, M::Error>
118    where
119        M: StructuredChatModel<Constraints> + Send + Sync + Clone + Unpin + 'static,
120        M::ChatSession: Clone + Send + Sync + Unpin + 'static,
121        M::Error: Send + Sync + Unpin,
122    {
123        let questions = self
124            .task
125            .run(text)
126            .with_constraints(create_constraints())
127            .await?;
128        let documents = questions
129            .1
130            .into_iter()
131            .map(|((i, _), s)| QUESTION_STARTERS[i].to_string() + &s)
132            .collect::<Vec<_>>();
133
134        Ok(documents)
135    }
136}
137
138/// An error that can occur when chunking a document with [`HypotheticalChunker`].
139#[derive(Debug, thiserror::Error)]
140pub enum HypotheticalChunkerError<E1: Send + Sync + 'static, E2: Send + Sync + 'static> {
141    /// An error from the text generation model.
142    #[error("Text generation model error: {0}")]
143    TextModelError(#[from] E1),
144    /// An error from the embedding model.
145    #[error("Embedding model error: {0}")]
146    EmbeddingModelError(E2),
147}
148
149impl<M> Chunker for Hypothetical<M>
150where
151    M: StructuredChatModel<Constraints> + Send + Sync + Clone + Unpin + 'static,
152    M::ChatSession: Clone + Send + Sync + Unpin + 'static,
153    M::Error: Send + Sync + Unpin,
154{
155    type Error<E: Send + Sync + 'static> = HypotheticalChunkerError<M::Error, E>;
156
157    async fn chunk<E: Embedder + Send>(
158        &self,
159        document: &Document,
160        embedder: &E,
161    ) -> Result<Vec<Chunk>, Self::Error<E::Error>> {
162        let body = document.body();
163
164        #[allow(clippy::single_range_in_vec_init)]
165        let byte_chunks = self
166            .chunking
167            .map(|chunking| chunking.chunk_str(body))
168            .unwrap_or_else(|| vec![0..body.len()]);
169
170        if byte_chunks.is_empty() {
171            return Ok(vec![]);
172        }
173
174        let mut questions = Vec::new();
175        let mut questions_count = Vec::new();
176        for byte_chunk in &byte_chunks {
177            let text = &body[byte_chunk.clone()];
178            let mut chunk_questions = self.generate_question(text).await?;
179            questions.append(&mut chunk_questions);
180            questions_count.push(chunk_questions.len());
181        }
182        let embeddings = embedder
183            .embed_vec(questions)
184            .await
185            .map_err(HypotheticalChunkerError::EmbeddingModelError)?;
186
187        let mut chunks = Vec::with_capacity(embeddings.len());
188        let mut questions_count = questions_count.iter();
189        let mut byte_chunks = byte_chunks.into_iter();
190
191        let mut remaining_embeddings = *questions_count.next().unwrap();
192        let mut byte_chunk = byte_chunks.next().unwrap();
193
194        for embedding in embeddings {
195            while remaining_embeddings == 0 {
196                if let Some(&questions_count) = questions_count.next() {
197                    remaining_embeddings = questions_count;
198                    byte_chunk = byte_chunks.next().unwrap();
199                }
200            }
201            remaining_embeddings -= 1;
202            chunks.push(Chunk {
203                byte_range: byte_chunk.clone(),
204                embeddings: vec![embedding],
205            });
206        }
207
208        Ok(chunks)
209    }
210}