use anyhow::{Result, bail};
use futures::future::join_all;
use log::info;
use std::{any::Any, sync::Arc};
use crate::{Command, cache::Cache, embedder::Embedder, input::Input};
/// Cosine similarity between two vectors, clamped into `[0.0, 1.0]`.
///
/// Returns `0.0` for vectors of different lengths and for zero-magnitude
/// vectors, so "incomparable" and "unrelated" score the same. Negative
/// cosine values (opposite directions) are also clamped up to `0.0`.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    // Accumulate dot product and both squared norms in a single pass.
    let mut dot = 0.0_f32;
    let mut sq_a = 0.0_f32;
    let mut sq_b = 0.0_f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let mag_a = sq_a.sqrt();
    let mag_b = sq_b.sqrt();
    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }
    (dot / (mag_a * mag_b)).clamp(0.0, 1.0)
}
/// Matches free-form text against registered commands by embedding
/// similarity and dispatches the best-scoring command's executor.
pub struct SemanticCommands<E: Embedder, Ch: Cache, C> {
    // Produces embeddings for text that is not already cached.
    embedder: Arc<E>,
    // Text -> embedding cache consulted before calling the embedder.
    cache: Arc<Ch>,
    // Shared context handed (as an `Arc` clone) to every command executor.
    context: Arc<C>,
    // Minimum cosine similarity required for a match; `new` sets 0.8.
    threshold: f32,
    // Each registered command paired with the example inputs that trigger it.
    entries: Vec<(Vec<Input>, Command<C>)>,
}
impl<E: Embedder, Ch: Cache, C> SemanticCommands<E, Ch, C> {
    /// Creates a matcher with an empty command table and the default
    /// similarity threshold of `0.8`.
    pub fn new(embedder: E, cache: Ch, context: C) -> Self {
        Self {
            embedder: Arc::new(embedder),
            cache: Arc::new(cache),
            context: Arc::new(context),
            threshold: 0.8,
            entries: vec![],
        }
    }

    /// Builder-style setter for the minimum cosine similarity a stored
    /// input must reach to count as a match.
    pub fn threshold(mut self, threshold: f32) -> Self {
        self.threshold = threshold;
        self
    }

    /// Returns the embedding for `input`, consulting the cache first and
    /// falling back to the embedder (storing the fresh result) on a miss.
    ///
    /// # Errors
    /// Propagates cache read/write failures and embedder failures.
    pub async fn get_embedding(&self, input: &str) -> Result<Vec<f32>> {
        match self.cache.get(input).await? {
            Some(embedding) => Ok(embedding),
            None => {
                info!("embedding not found in cache, generating new one");
                let embedding = self.embedder.as_ref().embed(input).await?;
                self.cache.put(input, embedding.clone()).await?;
                Ok(embedding)
            }
        }
    }

    /// Finds the stored input most similar to `embedding`, if any input
    /// scores at least `threshold`.
    ///
    /// Missing input embeddings are fetched concurrently first; an input
    /// whose fetch fails simply stays un-embedded and is skipped during
    /// matching (it will be retried on the next call).
    async fn find_similar(
        &mut self,
        embedding: Vec<f32>,
        threshold: f32,
    ) -> Result<Option<(&Input, &Command<C>)>> {
        // Texts still lacking an embedding, in entry order.
        let missing: Vec<_> = self
            .entries
            .iter()
            .flat_map(|(inputs, _)| inputs)
            .filter(|input| input.embedding.is_none())
            .map(|input| input.text.clone())
            .collect();
        // join_all preserves input order. Crucially we keep one result slot
        // per missing input (Ok or Err): discarding failures up front would
        // shift later embeddings onto the wrong inputs.
        let mut fetched =
            join_all(missing.iter().map(|text| self.get_embedding(text)))
                .await
                .into_iter();
        for (inputs, _) in &mut self.entries {
            for input in inputs.iter_mut().filter(|input| input.embedding.is_none()) {
                // Consume exactly one result per missing input; a failed
                // fetch leaves this input's embedding as None.
                input.embedding = fetched.next().and_then(Result::ok);
            }
        }
        // Scan every embedded input and keep the best scorer at or above
        // the threshold. total_cmp gives a total order on f32, so this
        // cannot panic even if a similarity is NaN.
        let query: &[f32] = &embedding;
        Ok(self
            .entries
            .iter()
            .flat_map(|(inputs, command)| {
                inputs.iter().filter_map(move |input| {
                    let similarity = cosine_similarity(query, input.embedding.as_ref()?);
                    (similarity >= threshold).then_some((similarity, input, command))
                })
            })
            .max_by(|a, b| a.0.total_cmp(&b.0))
            .map(|(_similarity, input, command)| (input, command)))
    }

    /// Embeds `input`, finds the best matching command, and runs its
    /// executor with a handle to the shared context.
    ///
    /// # Errors
    /// Fails if the input cannot be embedded or if no stored input reaches
    /// the configured threshold.
    pub async fn execute(&mut self, input: &str) -> Result<Box<dyn Any + Send>> {
        let input_embedding = self.get_embedding(input).await?;
        // Cheap Arc clone so the executor can own a context handle.
        let context = self.context.clone();
        match self.find_similar(input_embedding, self.threshold).await? {
            Some((_input, command)) => {
                info!("command recognized as: {:?}", command.name);
                Ok((command.executor)(context).await)
            }
            None => bail!("no similar command found"),
        }
    }

    /// Registers one command with the example inputs that should trigger it.
    pub fn add_command(&mut self, command: Command<C>, inputs: Vec<Input>) -> &mut Self {
        self.entries.push((inputs, command));
        self
    }

    /// Registers several commands at once; equivalent to calling
    /// `add_command` for each pair.
    pub fn add_commands(&mut self, commands: Vec<(Command<C>, Vec<Input>)>) -> &mut Self {
        self.entries
            .extend(commands.into_iter().map(|(command, inputs)| (inputs, command)));
        self
    }

    /// Initializes the underlying cache by delegating to `Cache::init`.
    ///
    /// # Errors
    /// Propagates cache initialization failures.
    pub async fn init(&mut self) -> Result<&mut Self> {
        self.cache.init().await?;
        Ok(self)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let c = vec![1.0, 0.0, 0.0];
        // Orthogonal vectors score 0; identical vectors score 1.
        assert!((cosine_similarity(&a, &b) - 0.0).abs() < 1e-6);
        assert!((cosine_similarity(&a, &c) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_edge_cases() {
        // Mismatched lengths are treated as incomparable.
        assert_eq!(cosine_similarity(&[1.0], &[1.0, 2.0]), 0.0);
        // A zero-magnitude vector cannot be normalized.
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 1.0]), 0.0);
        // Negative cosine values are clamped up into the [0, 1] range.
        assert_eq!(cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]), 0.0);
        // Empty slices: lengths match, but magnitude is zero.
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
    }
}