openai_func_embeddings/lib.rs
1use async_openai::{types::CreateEmbeddingRequestArgs, Client};
2use rkyv::{vec::ArchivedVec, Archive, Deserialize, Serialize};
3use std::fs::File;
4use std::io::Read;
5use std::path::Path;
6
7#[derive(Debug, Archive, Deserialize, Serialize)]
8#[archive(check_bytes)]
9#[archive_attr(derive(Debug))]
10pub struct FuncEmbedding {
11 pub name: String,
12 pub description: String,
13 pub embedding: Vec<f32>,
14}
15
16/// Asynchronously generates a single embedding vector for the given text using a specified model.
17///
18/// This function creates an embedding for the input text by calling an external service (e.g., OpenAI's
19/// API) with the specified model. It returns the embedding vector as a `Vec<f32>`.
20///
21/// # Parameters
22/// - `text`: A reference to a `String` containing the text to be embedded.
23/// - `model`: A string slice (`&str`) specifying the model to use for generating the embedding.
24///
25/// # Returns
26/// A `Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>>`:
27/// - `Ok(Vec<f32>)` containing the embedding vector if the operation is successful.
28/// - `Err(Box<dyn std::error::Error + Send + Sync>)` if there is an error during the operation,
29/// including issues with creating the request, network errors, or if the response does not contain an embedding.
30///
31/// # Errors
32/// This function can return an error in several cases, including:
33/// - Failure to build the embedding request.
34/// - Network or API errors when contacting the external service.
35/// - The response from the external service does not include an embedding vector.
36///
37/// # Example
38/// ```rust
39/// use std::path::Path;
40/// use your_module::{single_embedding, FuncEnumsError, Client, CreateEmbeddingRequestArgs};
41///
42/// #[tokio::main]
43/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
44/// let text = String::from("Your sample text here");
45/// let model = "your-model-name";
46///
47/// let embedding = single_embedding(&text, model).await?;
48/// println!("Embedding vector: {:?}", embedding);
49///
50/// Ok(())
51/// }
52/// ```
53pub async fn single_embedding(
54 text: &String,
55 model: &str,
56) -> Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>> {
57 let client = Client::new();
58 let request = CreateEmbeddingRequestArgs::default()
59 .model(model)
60 .input([text])
61 .build()?;
62
63 let response = client.embeddings().create(request).await?;
64
65 match response.data.first() {
66 Some(data) => Ok(data.embedding.to_owned()),
67 None => {
68 let embedding_error =
69 FuncEnumsError::OpenAIError(String::from("Didn't get embedding vector back."));
70 let boxed_error: Box<dyn std::error::Error + Send + Sync> = Box::new(embedding_error);
71 Err(boxed_error)
72 }
73 }
74}
75
/// Computes the cosine similarity between two vectors.
///
/// The result is `dot(vec1, vec2) / (|vec1| * |vec2|)`, which lies in
/// `[-1.0, 1.0]` for finite, non-zero inputs (up to floating-point rounding).
///
/// # Returns
/// - `0.0` if the vectors have different lengths. Cosine similarity is not
///   defined in that case, and comparing only the shared prefix while
///   normalising by the full magnitudes would give a meaningless score.
/// - `0.0` if either vector has zero magnitude (this includes empty slices).
/// - Otherwise the cosine of the angle between the two vectors.
///
/// NaN or infinite components propagate and can produce a NaN result.
pub fn cosine_similarity(vec1: &[f32], vec2: &[f32]) -> f32 {
    // `zip` stops at the shorter slice, so a length mismatch must be
    // rejected explicitly rather than silently truncated.
    if vec1.len() != vec2.len() {
        return 0.0;
    }

    // Accumulate the dot product and both squared norms in a single pass.
    let (dot_product, norm_sq1, norm_sq2) = vec1.iter().zip(vec2).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(dot, sq1, sq2), (&x1, &x2)| (dot + x1 * x2, sq1 + x1 * x1, sq2 + x2 * x2),
    );

    let magnitude1 = norm_sq1.sqrt();
    let magnitude2 = norm_sq2.sqrt();

    // Avoid dividing by zero for all-zero (or empty) vectors.
    if magnitude1 == 0.0 || magnitude2 == 0.0 {
        return 0.0;
    }

    dot_product / (magnitude1 * magnitude2)
}
87
88pub async fn rank_functions(
89 archived_embeddings: &ArchivedVec<ArchivedFuncEmbedding>,
90 input_vector: Vec<f32>,
91) -> Vec<String> {
92 let mut name_similarity_pairs: Vec<(String, f32)> = archived_embeddings
93 .iter()
94 .map(|archived_embedding| {
95 let archived_embedding_vec: &ArchivedVec<f32> = &archived_embedding.embedding;
96 let similarity = cosine_similarity(archived_embedding_vec, &input_vector);
97 (archived_embedding.name.to_string(), similarity)
98 })
99 .collect();
100
101 name_similarity_pairs
102 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
103
104 name_similarity_pairs
105 .into_iter()
106 .map(|(name, _)| name)
107 .collect()
108}
109
110/// Asynchronously retrieves and ranks function names based on their similarity to a given prompt embedding.
111///
112/// This function searches a specified file for function embeddings, compares them to the provided prompt embedding, and returns a ranked list of function names based on their similarity to the prompt.
113///
114/// # Parameters
115/// - `prompt_embedding`: A `Vec<f32>` representing the embedding of the prompt. This embedding is used to compare against the function embeddings stored in the file located at `embed_path`.
116/// - `embed_path`: A reference to a `Path` where the function embeddings are stored. This file should contain a serialized `Vec<FuncEmbedding>` where `FuncEmbedding` is a structure representing the function name and its embedding.
117///
118/// # Returns
119/// - `Ok(Vec<String>)`: A vector of function names ranked by their similarity to the `prompt_embedding`. The most similar function's name is first.
120/// - `Err(Box<dyn std::error::Error + Send + Sync>)`: An error if the file at `embed_path` cannot be opened, read, or if the embeddings cannot be deserialized and compared successfully.
121///
122/// # Errors
123/// - File opening failure due to `embed_path` not existing or being inaccessible.
124/// - File reading failure if the file cannot be read to the end.
125/// - Archive processing failure if deserialization of the stored embeddings encounters errors.
126///
127/// # Examples
128/// ```
129/// async fn run() -> Result<(), Box<dyn std::error::Error>> {
130/// let prompt_embedding = vec![0.1, 0.2, 0.3];
131/// let embed_path = Path::new("function_embeddings.bin");
132/// let ranked_function_names = get_ranked_function_names(prompt_embedding, embed_path).await?;
133/// println!("Ranked functions: {:?}", ranked_function_names);
134/// Ok(())
135/// }
136/// ```
137pub async fn get_ranked_function_names(
138 prompt_embedding: Vec<f32>,
139 embed_path: &Path,
140) -> Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> {
141 if embed_path.exists() {
142 let mut file = match File::open(embed_path) {
143 Ok(f) => f,
144 Err(e) => return Err(Box::new(e)),
145 };
146
147 let mut bytes = Vec::new();
148 if let Err(e) = file.read_to_end(&mut bytes) {
149 return Err(Box::new(e));
150 }
151
152 // TODO: Would be nice to check how much faster unsafe version of this is.
153 let archived_funcs =
154 rkyv::check_archived_root::<Vec<FuncEmbedding>>(&bytes).map_err(|e| {
155 Box::new(FuncEnumsError::RkyvError(format!(
156 "Archive processing failed: {}",
157 e
158 ))) as Box<dyn std::error::Error + Send + Sync>
159 })?;
160
161 Ok(rank_functions(archived_funcs, prompt_embedding).await)
162 } else {
163 Ok(vec![])
164 }
165}
166
/// Errors produced by this module that do not originate from an underlying
/// library error type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FuncEnumsError {
    /// The embedding service response was unusable; the payload is the message.
    OpenAIError(String),
    /// Archived embedding data failed validation; the payload is the message.
    RkyvError(String),
}

impl std::fmt::Display for FuncEnumsError {
    /// Writes a human-readable message for the error.
    ///
    /// Written by hand rather than delegating to `Debug`, so the payload is
    /// shown verbatim instead of being wrapped in the variant name with
    /// quotes and escape sequences.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            FuncEnumsError::OpenAIError(msg) => write!(f, "OpenAI error: {}", msg),
            FuncEnumsError::RkyvError(msg) => write!(f, "rkyv error: {}", msg),
        }
    }
}

impl std::error::Error for FuncEnumsError {}