use async_openai::{
    config::OpenAIConfig,
    types::{
        ChatCompletionRequestMessage, CreateChatCompletionRequestArgs, CreateEmbeddingRequestArgs,
        CreateEmbeddingResponse, Role,
    },
    Client,
};
use qdrant_client::prelude::{Payload, QdrantClient, QdrantClientConfig};
use qdrant_client::qdrant::{
    point_id::PointIdOptions, vectors::VectorsOptions, vectors_config::Config, CreateCollection,
    Distance, PointId, PointStruct, SearchPoints, Vector, VectorParams, Vectors, VectorsConfig,
};
use rand::Rng;
use serde::Serialize;
use tiktoken_rs::cl100k_base;

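/// Embeds `query` with OpenAI's `text-embedding-ada-002` model, then runs a
/// nearest-neighbour search for it against `collection_name` on the local
/// Qdrant instance and prints the raw results.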
pub async fn search_with_openai(
    query: &str,
    collection_name: &str,
    api_key: &str,
) -> anyhow::Result<()> {
    let config = OpenAIConfig::new().with_api_key(api_key);
    let openai_client = Client::with_config(config);

    let config = QdrantClientConfig::from_url("http://127.0.0.1:6334");
    let qdrant_client = QdrantClient::new(Some(config))?;

    // Embed the query with the same model used at indexing time.
    let request = CreateEmbeddingRequestArgs::default()
        .model("text-embedding-ada-002")
        .input(query)
        .build()?;

    let response: CreateEmbeddingResponse = openai_client.embeddings().create(request).await?;
    let question_vector = response.data[0].embedding.clone();

    // Top-10 nearest-neighbour search; every other option keeps its default.
    let search_result = qdrant_client
        .search_points(&SearchPoints {
            collection_name: collection_name.into(),
            vector: question_vector,
            limit: 10,
            ..Default::default()
        })
        .await?;
    dbg!(search_result);

    Ok(())
}

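/// Reads `file_path`, tokenizes it with the `cl100k_base` encoding, splits it
/// into ~4500-token chunks, then segments each chunk and uploads the segment
/// embeddings. The parameters are an assumption for this sketch: the original
/// signature took no arguments and the body of its upload loop was elided.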
pub async fn load_text(
    file_path: &str,
    collection_name: &str,
    api_key: &str,
) -> anyhow::Result<()> {
    let bpe = cl100k_base()?;

    // Read the book text itself; chunking the file *name* would be a bug.
    let text = std::fs::read_to_string(file_path)?;

    // Split the escaped text into ~4500-token chunks so each fits comfortably
    // in the chat model's context window.
    let chunked_text = bpe
        .encode_ordinary(&convert(&text))
        .chunks(4500)
        .map(|c| bpe.decode(c.to_vec()).unwrap())
        .collect::<Vec<String>>();

    let mut ids_vec = (0..10000u64).rev().collect::<Vec<u64>>();

    // Assumed wiring for the elided loop body: segment each chunk with the
    // chat model, then embed and upsert the segments into Qdrant.
    for chunk in chunked_text {
        let segments = segment_text(&chunk, api_key).await?;
        upload_embeddings(segments, &mut ids_vec, collection_name, api_key).await?;
    }

    Ok(())
}

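/// Creates a Qdrant collection sized for `text-embedding-ada-002` output:
/// 1536-dimensional vectors compared by cosine distance.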
pub async fn init_collection(collection_name: &str) -> anyhow::Result<()> {
    let config = QdrantClientConfig::from_url("http://127.0.0.1:6334");
    let qdrant_client = QdrantClient::new(Some(config))?;

    qdrant_client
        .create_collection(&CreateCollection {
            collection_name: collection_name.into(),
            vectors_config: Some(VectorsConfig {
                config: Some(Config::Params(VectorParams {
                    size: 1536,
                    distance: Distance::Cosine.into(),
                    ..Default::default()
                })),
            }),
            ..Default::default()
        })
        .await?;

    Ok(())
}

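/// Embeds every string in `inp` with a single OpenAI request, then upserts
/// the vectors into `collection_name`, drawing one id from `ids_vec` per
/// point and storing the trimmed source text as its payload.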
pub async fn upload_embeddings(
    inp: Vec<String>,
    ids_vec: &mut Vec<u64>,
    collection_name: &str,
    api_key: &str,
) -> anyhow::Result<()> {
    let config = OpenAIConfig::new().with_api_key(api_key);
    let openai_client = Client::with_config(config);
    let config = QdrantClientConfig::from_url("http://127.0.0.1:6334");
    let qdrant_client = QdrantClient::new(Some(config))?;

    let request = CreateEmbeddingRequestArgs::default()
        .model("text-embedding-ada-002")
        .input(&inp)
        .build()?;

    let response: CreateEmbeddingResponse = openai_client.embeddings().create(request).await?;
    let embeddings = response.data;
    let mut points = Vec::new();

    // Pair each input sentence with its embedding and a fresh point id; the
    // trimmed sentence rides along as a `text` payload.
    for (sentence, embedding) in inp.iter().zip(embeddings.iter()) {
        let id = ids_vec
            .pop()
            .ok_or_else(|| anyhow::anyhow!("ran out of point ids"))?;
        let payload: Payload = serde_json::json!({ "text": sentence.trim() })
            .try_into()
            .unwrap();
        let point = PointStruct::new(
            PointId {
                point_id_options: Some(PointIdOptions::Num(id)),
            },
            Vectors {
                vectors_options: Some(VectorsOptions::Vector(Vector {
                    data: embedding.embedding.clone(),
                })),
            },
            payload,
        );

        points.push(point);
    }

    qdrant_client
        .upsert_points_blocking(collection_name, points, None)
        .await?;
    Ok(())
}

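/// Asks `gpt-3.5-turbo-16k` to split a chapter into sentence- and
/// snippet-sized segments separated by the `~>_^~` marker, then parses the
/// reply into a vector of segments.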
pub async fn segment_text(inp: &str, api_key: &str) -> anyhow::Result<Vec<String>> {
    let config = OpenAIConfig::new().with_api_key(api_key);
    let openai_client = Client::with_config(config);

    let prompt = format!(
        r#"You are examining Chapter 1 of a book. Your mission is to dissect the provided information into short, logically divided segments to facilitate further processing afterwards.
Please adhere to the following steps:
1. Break down dense paragraphs into individual sentences, with each one functioning as a distinct chunk of information.
2. Treat code snippets as standalone entities and separate them from the accompanying text. Break long code snippets into chunks of fewer than 15 lines, respecting the programming-language constructs that keep one group of code together or separate it from another.
3. Take into account the original source's hierarchical markings and formatting specific to a book chapter. These elements can guide the logical segmentation process.
Keep in mind, the goal is not to summarize, but to restructure the information into more digestible, manageable units. Now, here is the text from the chapter: {inp}
Please reply in this format:
```
<sentence>~>_^~<sentence>~>_^~<sentence>
```"#
    );

    let system_message = ChatCompletionRequestMessage {
        role: Role::System,
        content: Some("As a dedicated assistant, your duty is to dissect the provided chapter text into clearer, bite-sized segments. To accomplish this, isolate each sentence and code snippet as an independent entity. Remember, your task is not to provide a summary, but to split the original text into a more granular sequence of texts, respecting the text's hierarchical markings and formatting as they contribute to the understanding of the text. Balance your interpretations with the original structure for an accurate representation. Reply in this format: ```<sentence>~>_^~<sentence>~>_^~<sentence>```".to_string()),
        name: None,
        function_call: None,
    };

    let user_message = ChatCompletionRequestMessage {
        role: Role::User,
        content: Some(prompt),
        name: None,
        function_call: None,
    };

    let request = CreateChatCompletionRequestArgs::default()
        .model("gpt-3.5-turbo-16k")
        .messages(vec![system_message, user_message])
        .max_tokens(7000_u16)
        .build()?;

    let response = openai_client.chat().create(request).await?;

    // The model's reply is a single string of segments joined by the
    // `~>_^~` delimiter requested in the prompt.
    match &response.choices[0].message.content {
        Some(raw_text) => Ok(raw_text
            .split("~>_^~")
            .map(|x| x.to_string())
            .collect::<Vec<_>>()),
        None => Err(anyhow::anyhow!("Could not get the text from OpenAI")),
    }
}

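/// A `serde_json` formatter that writes ASCII characters verbatim and escapes
/// everything else as JSON `\uXXXX` sequences.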
struct EscapeNonAscii;

impl serde_json::ser::Formatter for EscapeNonAscii {
    fn write_string_fragment<W: ?Sized + std::io::Write>(
        &mut self,
        writer: &mut W,
        fragment: &str,
    ) -> std::io::Result<()> {
        for ch in fragment.chars() {
            if ch.is_ascii() {
                writer.write_all(ch.encode_utf8(&mut [0; 4]).as_bytes())?;
            } else {
                // JSON `\u` escapes encode UTF-16 code units, so characters
                // outside the Basic Multilingual Plane must be written as a
                // surrogate pair rather than one escape of the full code point.
                for unit in ch.encode_utf16(&mut [0u16; 2]) {
                    write!(writer, "\\u{:04x}", unit)?;
                }
            }
        }
        Ok(())
    }
}

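/// Serializes `input` as a JSON string using `EscapeNonAscii`, returning a
/// quoted, fully ASCII representation.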
pub fn convert(input: &str) -> String {
    let mut writer = Vec::new();
    let formatter = EscapeNonAscii;
    let mut ser = serde_json::Serializer::with_formatter(&mut writer, formatter);
    input.serialize(&mut ser).unwrap();
    String::from_utf8(writer).unwrap()
}

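/// Produces 9,900 distinct random `u64` values to use as Qdrant point ids.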
pub fn gen_ids() -> Vec<u64> {
    let mut rng = rand::thread_rng();
    let mut set = std::collections::HashSet::new();

    while set.len() < 9900 {
        set.insert(rng.gen::<u64>());
    }
    set.into_iter().collect::<Vec<u64>>()
}
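
// A minimal end-to-end sketch of how these functions compose, assuming a
// Qdrant instance at 127.0.0.1:6334 and an `OPENAI_API_KEY` environment
// variable; the "book" collection name and the query string are illustrative,
// not part of the original module.
#[allow(dead_code)]
pub async fn run_pipeline() -> anyhow::Result<()> {
    let api_key = std::env::var("OPENAI_API_KEY")?;
    let collection = "book";

    init_collection(collection).await?;
    load_text("book.txt", collection, &api_key).await?;
    search_with_openai("What does the first chapter cover?", collection, &api_key).await?;

    Ok(())
}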