reco_forge/lib.rs
//! This crate provides an interface for users to turn any dataset with titles and descriptions into a recommendation system. It uses the BERT model to create embeddings for each item in the dataset and then finds recommendations based on the user's input.
//!
//! To run the examples, you can clone the git repository at <https://github.com/jameslk3/reco-forge> and then run the following commands:
//! ```bash
//! cargo run --example description
//! cargo run --example item
//! ```
//!
//! Example usage:
//! ```no_run
//! use reco_forge::{create_model, pass_description, Data, Tensor, HashMap}; // Can also use pass_item
//!
//! fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let mut path = String::new();
//!     println!("Please enter the file path:");
//!
//!     let model_wrapped: Result<HashMap<Data, Option<Tensor>>, ()>;
//!
//!     loop {
//!         std::io::stdin().read_line(&mut path).expect("Failed to read line");
//!         path = path.trim().to_string();
//!
//!         if let Ok(model) = create_model(&path) {
//!             model_wrapped = Ok(model);
//!             break;
//!         } else {
//!             println!("File path is not valid or file cannot be deserialized, please input the correct file path and try again:");
//!             path.clear();
//!         }
//!     }
//!
//!     let model: HashMap<Data, Option<Tensor>> = model_wrapped.unwrap();
//!
//!     println!("Input tags that you would like to use to filter, else enter NONE");
//!     let mut tags_input: String = String::new();
//!     std::io::stdin().read_line(&mut tags_input).expect("Failed to read line");
//!     tags_input = tags_input.trim().to_string();
//!
//!     println!("Describe what you want to be recommended:");
//!     let mut query: String = String::new();
//!     std::io::stdin().read_line(&mut query).expect("Failed to read line");
//!     query = query.trim().to_string();
//!     println!();
//!
//!     let recommendations = pass_description(&model, query, tags_input, 10);
//!     match recommendations {
//!         Ok(recommendations) => {
//!             println!("Recommendations:");
//!             for recommendation in recommendations {
//!                 println!("{}% {}", (recommendation.1 * 100.0).round(), recommendation.0);
//!             }
//!         },
//!         Err(_) => println!("No recommendations found"),
//!     }
//!     Ok(())
//! }
//! ```
//! Examples are also provided in the examples folder of the git repository at <https://github.com/jameslk3/reco-forge>.
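//!
//! If you already have a specific item in the dataset and want similar items, use `pass_item` instead of `pass_description`. Below is a minimal sketch; the file path and item name are placeholders, and the item name must match the `name` field of an entry in your dataset:
//! ```no_run
//! use reco_forge::{create_model, pass_item};
//!
//! let model = create_model(&"path/to/data.json".to_string()).unwrap();
//! let recommendations = pass_item(&model, "Some Item Name".to_string(), "NONE".to_string(), 5);
//! if let Ok(recommendations) = recommendations {
//!     for (name, similarity) in recommendations {
//!         println!("{}% {}", (similarity * 100.0).round(), name);
//!     }
//! }
//! ```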
//!
//! Required JSON file format:
//! ```json
//! [
//!     {
//!         "id": int,
//!         "name": "string",
//!         "summary": "string",
//!         "tags": ["string1", "string2"]
//!     },
//!     {
//!         ...
//!     }
//! ]
//! ```
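//!
//! For example, a small dataset file might look like this (the entries below are purely illustrative):
//! ```json
//! [
//!     {
//!         "id": 1,
//!         "name": "Item One",
//!         "summary": "A calm documentary about deep space.",
//!         "tags": ["space", "documentary"]
//!     },
//!     {
//!         "id": 2,
//!         "name": "Item Two",
//!         "summary": "A fast-paced heist thriller.",
//!         "tags": ["thriller", "crime"]
//!     }
//! ]
//! ```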

pub(crate) mod helpers;

extern crate candle;

pub use candle::Tensor;
pub use helpers::types::Data;
pub use std::collections::HashMap;

use helpers::pre_recommendation::{extract_data, insert_embeddings, find_embedding};
use helpers::recommendation::{get_recommendations, create_input_embedding};

/// # create_model
/// This function creates the model from the data file at the path given by the user.
///
/// # Arguments
/// * `file_path: &String` - The path to the JSON data file
///
/// # Returns
/// * `Result<HashMap<Data, Option<Tensor>>, String>` - The model if it was created successfully, otherwise a wrapped error message
///
/// # Example
/// ```no_run
/// use reco_forge::create_model;
///
/// let file_path = "path/to/data.json".to_string();
/// let model = create_model(&file_path);
/// match model {
///     Ok(_model) => println!("Model created successfully"),
///     Err(e) => println!("Error: {}", e),
/// }
/// ```
pub fn create_model(file_path: &String) -> Result<HashMap<Data, Option<Tensor>>, String> {
    // Load and deserialize the dataset into nodes keyed by their Data entry.
    let nodes_wrapped: Result<HashMap<Data, Option<Tensor>>, ()> = extract_data(file_path);
    if nodes_wrapped.is_err() {
        return Err("File path is not valid or file cannot be deserialized, please input the correct file path and try again:".to_string());
    }
    let mut nodes: HashMap<Data, Option<Tensor>> = nodes_wrapped.unwrap();
    println!("Creating model, please be patient...");
    // Compute a BERT embedding for each item and store it alongside its Data entry.
    if insert_embeddings(&mut nodes).is_ok() {
        return Ok(nodes);
    }
    Err("Error inserting embeddings".to_string())
}

/// # pass_description
/// This function is used when the user wants to find recommendations based on a description.
///
/// # Arguments
/// * `node_embeddings: &HashMap<Data, Option<Tensor>>` - The model
/// * `description_input: String` - The description input by the user
/// * `tags_input: String` - The tags input by the user, each tag separated by a comma. If the user doesn't want to filter by tags, they can enter NONE
/// * `num_recommendations: usize` - The number of recommendations the user wants
///
/// # Returns
/// * `Result<Vec<(String, f32)>, ()>` - A vector of (item name, similarity) tuples if recommendations were found, otherwise Err
///
/// # Example
/// ```no_run
/// use reco_forge::{create_model, pass_description};
///
/// let model = create_model(&"path/to/data.json".to_string()).unwrap();
/// let recommendations = pass_description(&model, "description".to_string(), "tag1,tag2".to_string(), 10);
/// match recommendations {
///     Ok(recommendations) => {
///         println!("Recommendations:");
///         for recommendation in recommendations {
///             println!("{}% {}", (recommendation.1 * 100.0).round(), recommendation.0);
///         }
///     },
///     Err(_) => println!("No recommendations found"),
/// }
/// ```
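///
/// To skip tag filtering, pass the literal string "NONE"; a minimal variation of the example above:
/// ```no_run
/// # use reco_forge::{create_model, pass_description};
/// # let model = create_model(&"path/to/data.json".to_string()).unwrap();
/// let recommendations = pass_description(&model, "a calm documentary about deep space".to_string(), "NONE".to_string(), 5);
/// ```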
pub fn pass_description(node_embeddings: &HashMap<Data, Option<Tensor>>, description_input: String, tags_input: String, num_recommendations: usize) -> Result<Vec<(String, f32)>, ()> {
    // Given a description, embed it with the same BERT model used for the dataset and
    // rank the items in the model by similarity to that embedding.
    let input_embedding = create_input_embedding(&description_input).unwrap().unwrap();
    get_recommendations(node_embeddings, None, &input_embedding, &tags_input, num_recommendations)
}

/// # pass_item
/// This function is used when the user wants to find recommendations based on a specific item that is already in the model.
///
/// # Arguments
/// * `node_embeddings: &HashMap<Data, Option<Tensor>>` - The model
/// * `item: String` - The item the user wants recommendations for
/// * `tags_input: String` - The tags input by the user, each tag separated by a comma. If the user doesn't want to filter by tags, they can enter NONE
/// * `num_recommendations: usize` - The number of recommendations the user wants
///
/// # Returns
/// * `Result<Vec<(String, f32)>, ()>` - A vector of (item name, similarity) tuples if recommendations were found, otherwise Err
///
/// # Example
/// ```no_run
/// use reco_forge::{create_model, pass_item};
///
/// let model = create_model(&"path/to/data.json".to_string()).unwrap();
/// let recommendations = pass_item(&model, "item".to_string(), "tag1,tag2".to_string(), 10);
/// match recommendations {
///     Ok(recommendations) => {
///         println!("Recommendations:");
///         for recommendation in recommendations {
///             println!("{}% {}", (recommendation.1 * 100.0).round(), recommendation.0);
///         }
///     },
///     Err(_) => println!("No recommendations found"),
/// }
/// ```
pub fn pass_item(node_embeddings: &HashMap<Data, Option<Tensor>>, item: String, tags_input: String, num_recommendations: usize) -> Result<Vec<(String, f32)>, ()> {
    // The item must already be in the model: look up its stored embedding and bail out if it is missing.
    let input_embedding = match find_embedding(node_embeddings, &item) {
        Ok(embedding) => embedding,
        Err(_) => return Err(()),
    };
    // Rank the items in the model by similarity to the retrieved embedding.
    get_recommendations(node_embeddings, Some(&item), &input_embedding, &tags_input, num_recommendations)
}
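
// The module below is an illustrative sketch, not part of the public API: an #[ignore]d smoke test
// showing the intended end-to-end flow (write a small JSON dataset, build the model, ask for a
// recommendation). The file name and dataset contents are placeholders, and the test is ignored by
// default because building the model loads the BERT weights.
#[cfg(test)]
mod smoke_test_sketch {
    use super::*;

    #[test]
    #[ignore]
    fn description_recommendation_flow() {
        // Write a tiny dataset in the JSON format documented at the top of this file.
        let sample = r#"[
            {"id": 1, "name": "Item One", "summary": "A calm documentary about deep space.", "tags": ["space"]},
            {"id": 2, "name": "Item Two", "summary": "A fast-paced heist thriller.", "tags": ["thriller"]}
        ]"#;
        let path = "smoke_test_data.json".to_string();
        std::fs::write(&path, sample).expect("failed to write sample dataset");

        // Build the model and request a single recommendation with no tag filter.
        let model = create_model(&path).expect("model creation failed");
        let recommendations =
            pass_description(&model, "space exploration".to_string(), "NONE".to_string(), 1)
                .expect("expected at least one recommendation");
        assert!(!recommendations.is_empty());

        std::fs::remove_file(&path).ok();
    }
}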