rag/
rag.rs

1use espionox::{
2    agents::memory::ToMessage,
3    language_models::embeddings::{error::EmbeddingError, EmbeddingModel},
4    prelude::*,
5};
6use std::collections::HashMap;
7
8#[derive(Debug)]
9pub struct RagManager<'p> {
10    /// RAG listeners can be used as long as they have some connection to a data source. In this
11    /// example we use a vector, but it could be anything, including a Database pool.
12    data: Option<DbStruct<'p>>,
13    /// It depends on your implementation and Data source, but in this example, our RAG listener
14    /// will require access to an embedderr
15    embedder: EmbeddingModel,
16}
17
/// A single catalogue entry together with the embedding of its description.
#[derive(Debug, Clone)]
pub struct Product<'p> {
    /// Name of the product as shown to the model.
    name: &'p str,
    /// Free-text description of the product; this is the text that gets embedded.
    description: &'p str,
    /// Embedding vector computed from `description`.
    desc_embedding: Vec<f32>,
}
24
25#[derive(Debug)]
26pub struct DbStruct<'p>(Vec<Product<'p>>);
27
/// Computes the Euclidean (L2) distance between two embedding vectors.
///
/// A smaller result means the vectors are closer together. In production your database will
/// likely provide this for you.
///
/// Elements are paired by position. If the slices differ in length, only the overlapping prefix
/// is compared and the surplus elements of the longer slice are ignored, so callers should pass
/// vectors of equal dimension. Two empty slices have a distance of `0.0`.
fn score_l2(one: &[f32], other: &[f32]) -> f32 {
    one.iter()
        .zip(other)
        .map(|(x, y)| (x - y).powi(2))
        .sum::<f32>()
        .sqrt()
}
39
40impl<'p> RagManager<'p> {
41    async fn embed(&mut self, str: &str) -> Result<Vec<f32>, EmbeddingError> {
42        let embedding: Vec<f32> = self.embedder.get_embedding(str).await?;
43        Ok(embedding)
44    }
45
46    async fn init_products(&mut self) {
47        let data = DbStruct(vec![
48        Product {
49            name: "SmartWatch 2000",
50            description: "Stay connected and track your fitness with the SmartWatch 2000. This sleek device features a vibrant touchscreen, heart rate monitoring, and a variety of smart notifications.",
51            desc_embedding: self.embed("Stay connected and track your fitness with the SmartWatch 2000. This sleek device features a vibrant touchscreen, heart rate monitoring, and a variety of smart notifications.").await.unwrap(),
52        },
53
54        Product {
55            name: "Quantum Laptop Pro",
56            description: "Unleash the power of productivity with the Quantum Laptop Pro. Equipped with a high-performance processor, stunning display, and a lightweight design, it's your perfect companion for work and play.",
57            desc_embedding: self.embed("Unleash the power of productivity with the Quantum Laptop Pro. Equipped with a high-performance processor, stunning display, and a lightweight design, it's your perfect companion for work and play.").await.unwrap(),
58        },
59
60        Product {
61            name: "ZenAir Noise-Canceling Headphones",
62            description: "Immerse yourself in crystal-clear sound with the ZenAir Noise-Canceling Headphones. These wireless over-ear headphones offer premium comfort and cutting-edge noise-canceling technology for an unparalleled audio experience.",
63            desc_embedding: self.embed("Immerse yourself in crystal-clear sound with the ZenAir Noise-Canceling Headphones. These wireless over-ear headphones offer premium comfort and cutting-edge noise-canceling technology for an unparalleled audio experience.").await.unwrap(),
64        },
65
66        Product {
67            name: "Eco-Friendly Bamboo Water Bottle",
68            description: "Make a statement while staying eco-friendly with our Bamboo Water Bottle. Crafted from sustainable bamboo, this stylish and reusable bottle is perfect for staying hydrated on the go.",
69            desc_embedding: self.embed("Make a statement while staying eco-friendly with our Bamboo Water Bottle. Crafted from sustainable bamboo, this stylish and reusable bottle is perfect for staying hydrated on the go.").await.unwrap(),
70        },
71
72        Product {
73            name: "Stellar Telescope 4000X",
74            description: "Explore the wonders of the night sky with the Stellar Telescope 4000X. This high-powered telescope is perfect for astronomy enthusiasts, featuring advanced optics and a sturdy mount for clear and detailed views.",
75            desc_embedding: self.embed("Explore the wonders of the night sky with the Stellar Telescope 4000X. This high-powered telescope is perfect for astronomy enthusiasts, featuring advanced optics and a sturdy mount for clear and detailed views.").await.unwrap(),
76        },
77
78        Product {
79            name: "Gourmet Coffee Sampler Pack",
80            description: "Indulge your taste buds with our Gourmet Coffee Sampler Pack. This curated collection includes a variety of premium coffee blends from around the world, offering a delightful coffee experience.",
81            desc_embedding: self.embed("Indulge your taste buds with our Gourmet Coffee Sampler Pack. This curated collection includes a variety of premium coffee blends from around the world, offering a delightful coffee experience.").await.unwrap(),
82        },
83
84        Product {
85            name: "Fitness Tracker Pro",
86            description: "Achieve your fitness goals with the Fitness Tracker Pro. Monitor your steps, heart rate, and sleep patterns while receiving real-time notifications. Sleek design and long battery life make it an essential companion for an active lifestyle.",
87            desc_embedding: self.embed("Achieve your fitness goals with the Fitness Tracker Pro. Monitor your steps, heart rate, and sleep patterns while receiving real-time notifications. Sleek design and long battery life make it an essential companion for an active lifestyle.").await.unwrap(),
88        },
89
90        Product {
91            name: "Retro Arcade Gaming Console",
92            description: "Relive the nostalgia of classic arcade games with our Retro Arcade Gaming Console. Packed with your favorite titles, this compact console brings back the joy of retro gaming in a modern and portable design.",
93            desc_embedding: self.embed("Relive the nostalgia of classic arcade games with our Retro Arcade Gaming Console. Packed with your favorite titles, this compact console brings back the joy of retro gaming in a modern and portable design.").await.unwrap(),
94        },
95
96        Product {
97            name: "Luxe Leather Messenger Bag",
98            description: "Elevate your style with the Luxe Leather Messenger Bag. Crafted from premium leather, this sophisticated bag combines fashion and functionality, offering ample space for your essentials in a timeless design.",
99            desc_embedding: self.embed("Elevate your style with the Luxe Leather Messenger Bag. Crafted from premium leather, this sophisticated bag combines fashion and functionality, offering ample space for your essentials in a timeless design.").await.unwrap(),
100        },
101
102       Product {
103            name: "Herbal Infusion Tea Set",
104            description: "Unwind and savor the soothing flavors of our Herbal Infusion Tea Set. This carefully curated collection features a blend of herbal teas, each with unique health benefits and delightful aromas.",
105            desc_embedding: self.embed("Unwind and savor the soothing flavors of our Herbal Infusion Tea Set. This carefully curated collection features a blend of herbal teas, each with unique health benefits and delightful aromas.").await.unwrap(),
106        }
107    ]);
108        self.data = Some(data);
109    }
110}
111
112/// We'll implement ToMessage for our DbStruct so we have control over how the model sees the data
113/// it's given
114impl<'p> ToMessage for DbStruct<'p> {
115    fn to_message(&self, role: MessageRole) -> Message {
116        let mut content = String::from("Answer the user's query based on the provided data:");
117        self.0.iter().for_each(|p| {
118            content.push_str(&format!(
119                "\nProduct Name: {}\nProduct Description: {}",
120                p.name, p.description
121            ));
122        });
123        Message { role, content }
124    }
125}
126
127impl<'p> DbStruct<'p> {
128    /// A simple helper function to get similar data given a query embedding
129    /// KEEP IN MIND THIS IS JUST FOR AN EXAMPLE, I DO NOT RECCOMEND VECTOR QUERYING AN ARRAY LIKE
130    /// THIS IN PROD
131    async fn get_close_embeddings_from_query_embedding(
132        &self,
133        qembed: Vec<f32>,
134        amt: usize,
135    ) -> DbStruct<'p> {
136        let mut map = HashMap::new();
137        let mut scores: Vec<f32> = self
138            .0
139            .iter()
140            .map(|p| {
141                let score = score_l2(&qembed, &p.desc_embedding);
142                map.insert((score * 100.0) as u32, p);
143                println!("Score for: {} is {}", p.name, score);
144                score
145            })
146            .collect();
147        scores.sort_by(|a, b| a.total_cmp(b));
148        let closest = scores[..amt].into_iter().fold(vec![], |mut acc, s| {
149            let score_key = (s * 100.0) as u32;
150            if let Some(val) = map.remove(&score_key) {
151                acc.push(val.to_owned())
152            }
153            acc
154        });
155        DbStruct(closest)
156    }
157}
158
159#[tokio::main]
160async fn main() {
161    dotenv::dotenv().ok();
162    let api_key = std::env::var("OPENAI_KEY").unwrap();
163    let embedder = EmbeddingModel::default_openai(&api_key);
164    let mut agent = Agent::new(
165        Some("You are jerry!!"),
166        CompletionModel::default_openai(&api_key),
167    );
168
169    let mut rag = RagManager {
170        embedder,
171        data: None,
172    };
173
174    rag.init_products().await;
175    // agent.insert_listener(listener);
176
177    let m = Message::new_user("I need a new fitness toy, what is the best product for me?");
178    let message_embedding = rag
179        .embed(&m.content)
180        .await
181        .expect("Failed to embed message content");
182    let relavent = rag
183        .data
184        .as_ref()
185        .unwrap()
186        .get_close_embeddings_from_query_embedding(message_embedding, 5)
187        .await;
188    println!("Got relavent structs: {relavent:#?}");
189    agent.cache.push(m);
190    agent.cache.push(relavent.to_message(MessageRole::User));
191    let response = agent.io_completion().await.unwrap();
192    println!("{:?}", response);
193}