chatgpt_private_test/
utils.rs

1use async_openai::{
2    config::OpenAIConfig,
3    types::{
4        ChatCompletionRequestMessage, CreateChatCompletionRequestArgs, CreateEmbeddingRequestArgs,
5        CreateEmbeddingResponse, Role,
6    },
7    Client,
8};
9use qdrant_client::prelude::{Payload, QdrantClient, QdrantClientConfig};
10use qdrant_client::qdrant::{
11    point_id::PointIdOptions, vectors::VectorsOptions, vectors_config::Config, CreateCollection,
12    Distance, PointId, PointStruct, SearchPoints, Vector, VectorParams, Vectors, VectorsConfig,
13};
14use rand::Rng;
15use serde::Serialize;
16use std::default::Default;
17use tiktoken_rs::cl100k_base;
18
19pub async fn search_with_openai(
20    query: &str,
21    collection_name: &str,
22    api_key: &str,
23) -> anyhow::Result<()> {
24    let config = OpenAIConfig::new().with_api_key(api_key);
25    let openai_client = Client::with_config(config);
26
27    let config = QdrantClientConfig::from_url("http://127.0.0.1:6334");
28    let qdrant_client = QdrantClient::new(Some(config))?;
29
30    let request = CreateEmbeddingRequestArgs::default()
31        .model("text-embedding-ada-002")
32        .input(query)
33        .build()
34        .unwrap();
35
36    let response: CreateEmbeddingResponse = openai_client.embeddings().create(request).await?;
37    let question_vector = response.data[0].clone().embedding;
38
39    let search_result = qdrant_client
40        .search_points(&SearchPoints {
41            collection_name: collection_name.into(),
42            vector: question_vector,
43            filter: None,
44            limit: 10,
45            with_vectors: None,
46            with_payload: None,
47            params: None,
48            score_threshold: None,
49            offset: None,
50            ..Default::default()
51        })
52        .await?;
53    dbg!(search_result);
54
55    Ok(())
56}
57
58pub async fn load_text() -> anyhow::Result<()> {
59    let bpe = cl100k_base().unwrap();
60
61    // let s = include_str!("book.txt");
62    let s = "book.txt";
63
64    let chunked_text = bpe
65        .encode_ordinary(&convert(s))
66        .chunks(4500)
67        .map(|c| bpe.decode(c.to_vec()).unwrap())
68        .collect::<Vec<String>>();
69
70    let mut ids_vec = (0..10000u64).into_iter().rev().collect::<Vec<u64>>();
71
72    for chunk in chunked_text {
73        // if let Ok(segment) = segment_text(&chunk, api_key).await {
74        //     for seg in &segment {
75        //         println!("{}\n", seg);
76        //     }
77        // }
78    }
79
80    Ok(())
81}
82
83pub async fn init_collection(collection_name: &str) -> anyhow::Result<()> {
84    let config = QdrantClientConfig::from_url("http://127.0.0.1:6334");
85    let qdrant_client = QdrantClient::new(Some(config))?;
86
87    qdrant_client
88        .create_collection(&CreateCollection {
89            collection_name: collection_name.into(),
90            vectors_config: Some(VectorsConfig {
91                config: Some(Config::Params(VectorParams {
92                    size: 1536,
93                    distance: Distance::Cosine.into(),
94                    hnsw_config: None,
95                    quantization_config: None,
96                    on_disk: None,
97                })),
98            }),
99            ..Default::default()
100        })
101        .await?;
102
103    Ok(())
104}
105
106pub async fn upload_embeddings(
107    inp: Vec<String>,
108    ids_vec: &mut Vec<u64>,
109    collection_name: &str,
110    api_key: &str,
111) -> anyhow::Result<()> {
112    let config = OpenAIConfig::new().with_api_key(api_key);
113    let openai_client = Client::with_config(config);
114    let config = QdrantClientConfig::from_url("http://127.0.0.1:6334");
115    let qdrant_client = QdrantClient::new(Some(config))?;
116
117    let request = CreateEmbeddingRequestArgs::default()
118        .model("text-embedding-ada-002")
119        .input(&inp)
120        .build()?;
121
122    let response: CreateEmbeddingResponse = openai_client.embeddings().create(request).await?;
123    let embeddings = response.data;
124    let mut points = Vec::new();
125
126    for (i, sentence) in inp.iter().enumerate() {
127        let id = ids_vec.pop().unwrap();
128        let payload: Payload = serde_json::json!({ "text": sentence.trim().to_string()})
129            .try_into()
130            .unwrap();
131        let point = PointStruct::new(
132            PointId {
133                point_id_options: Some(PointIdOptions::Num(id)),
134            },
135            Vectors {
136                vectors_options: Some(VectorsOptions::Vector(Vector {
137                    data: embeddings[i].clone().embedding,
138                })),
139            },
140            payload,
141        );
142
143        points.push(point);
144    }
145
146    qdrant_client
147        .upsert_points_blocking(collection_name, points, None)
148        .await?;
149    Ok(())
150}
151
152pub async fn segment_text(inp: &str, api_key: &str) -> anyhow::Result<Vec<String>> {
153    let config = OpenAIConfig::new().with_api_key(api_key);
154    let openai_client = Client::with_config(config);
155
156    let prompt = format!(
157        r#"You are examining Chapter 1 of a book. Your mission is to dissect the provided information into short, logically divided segments to facilitate further processing afterwards. 
158    Please adhere to the following steps:
159    1. Break down dense paragraphs into individual sentences, with each one functioning as a distinct chunk of information. 
160    2. Consider code snippets as standalone entities and separate them from the accompanying text, break down long code snippets to chunks of less than 15 lines, please respect the programming language constructs that keep a group of codes together or separate one group of codes from another. 
161    3. Take into account the original source's hierarchical markings and formatting specific to a book chapter. These elements can guide the logical segmentation process.
162    Keep in mind, the goal is not to summarize, but to restructure the information into more digestible, manageable units. Now, here is the text from the chapter:{inp}".
163    Please reply in this format:
164```
165<sentence>~>_^~<sentence>~>_^~<sentence>
166```"#
167    );
168    let system_message = ChatCompletionRequestMessage {
169        role: Role::System,
170        content: Some("As a dedicated assistant, your duty is to dissect the provided chapter text into clearer, bite-sized segments. To accomplish this, isolate each sentence and code snippet as independent entities. Remember, your task is not to provide a summary, but to split the original text into a texts sequence more granunlar, respecting the text's hierarchical markings and formatting as they contribute to the understanding of the text. Balance your interpretations with the original structure for an accurate representation. reply in this format: ```<sentence>~>_^~<sentence>~>_^~<sentence>```".to_string()),
171        name: None,
172        function_call: None,
173};
174
175    let user_message = ChatCompletionRequestMessage {
176        role: Role::User,
177        content: Some(prompt),
178        name: None,
179        function_call: None,
180    };
181
182    let request = CreateChatCompletionRequestArgs::default()
183        .model("gpt-3.5-turbo-16k")
184        .messages(vec![system_message, user_message])
185        .max_tokens(7000_u16)
186        .build()?;
187
188    let response = openai_client
189        .chat() // Get the API "group" (completions, images, etc.) from the client
190        .create(request) // Make the API call in that "group"
191        .await?;
192
193    match &response.choices[0].message.content {
194        Some(raw_text) => Ok(raw_text
195            .split("~>_^~")
196            .map(|x| x.to_string())
197            .collect::<Vec<_>>()),
198        None => Err(anyhow::anyhow!("Could not get the text from OpenAI")),
199    }
200}
201
202struct EscapeNonAscii;
203
204impl serde_json::ser::Formatter for EscapeNonAscii {
205    fn write_string_fragment<W: ?Sized + std::io::Write>(
206        &mut self,
207        writer: &mut W,
208        fragment: &str,
209    ) -> std::io::Result<()> {
210        for ch in fragment.chars() {
211            if ch.is_ascii() {
212                writer.write_all(ch.encode_utf8(&mut [0; 4]).as_bytes())?;
213            } else {
214                write!(writer, "\\u{:04x}", ch as u32)?;
215                // write!(writer, "?")?;
216            }
217        }
218        Ok(())
219    }
220}
221
222pub fn convert(input: &str) -> String {
223    let mut writer = Vec::new();
224    let formatter = EscapeNonAscii;
225    let mut ser = serde_json::Serializer::with_formatter(&mut writer, formatter);
226    input.serialize(&mut ser).unwrap();
227    String::from_utf8(writer).unwrap()
228}
229
230pub fn gen_ids() -> Vec<u64> {
231    let mut rng = rand::thread_rng();
232    let mut set = std::collections::HashSet::new();
233
234    while set.len() < 9900 {
235        set.insert(rng.gen::<u64>());
236    }
237    set.into_iter().collect::<Vec<u64>>()
238}