sqlite_knowledge_graph/
functions.rs1use crate::error::Result;
8
9pub fn register_functions(conn: &rusqlite::Connection) -> Result<()> {
11 conn.create_scalar_function(
13 "kg_cosine_similarity",
14 2,
15 rusqlite::functions::FunctionFlags::SQLITE_UTF8,
16 |ctx| {
17 let vec1_blob: Vec<u8> = ctx.get(0)?;
18 let vec2_blob: Vec<u8> = ctx.get(1)?;
19
20 let mut vec1 = Vec::new();
22 for chunk in vec1_blob.chunks_exact(4) {
23 let bytes: [u8; 4] = match chunk.try_into() {
24 Ok(b) => b,
25 Err(_) => return Ok(0.0f64),
26 };
27 vec1.push(f32::from_le_bytes(bytes));
28 }
29
30 let mut vec2 = Vec::new();
31 for chunk in vec2_blob.chunks_exact(4) {
32 let bytes: [u8; 4] = match chunk.try_into() {
33 Ok(b) => b,
34 Err(_) => return Ok(0.0f64),
35 };
36 vec2.push(f32::from_le_bytes(bytes));
37 }
38
39 if vec1.len() != vec2.len() {
40 return Ok(0.0f64);
41 }
42
43 let mut dot_product = 0.0_f32;
44 let mut norm_a = 0.0_f32;
45 let mut norm_b = 0.0_f32;
46
47 for i in 0..vec1.len() {
48 dot_product += vec1[i] * vec2[i];
49 norm_a += vec1[i] * vec1[i];
50 norm_b += vec2[i] * vec2[i];
51 }
52
53 if norm_a == 0.0 || norm_b == 0.0 {
54 return Ok(0.0f64);
55 }
56
57 let similarity = dot_product / (norm_a.sqrt() * norm_b.sqrt());
58 Ok(similarity as f64)
59 },
60 )?;
61
62 Ok(())
63}
64
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::{params, Connection};

    /// Packs `values` as little-endian `f32`s, the blob layout that
    /// `kg_cosine_similarity` decodes.
    fn f32_blob(values: &[f32]) -> Vec<u8> {
        values.iter().flat_map(|v| v.to_le_bytes()).collect()
    }

    /// Opens an in-memory database with the schema and SQL functions registered.
    fn setup() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();
        // `expect` (rather than `assert!(..is_ok())`) surfaces the actual error.
        register_functions(&conn).expect("registering SQL functions should succeed");
        conn
    }

    /// Evaluates `kg_cosine_similarity` on two vectors through SQL.
    fn cosine(conn: &Connection, a: &[f32], b: &[f32]) -> f64 {
        conn.query_row(
            "SELECT kg_cosine_similarity(?1, ?2)",
            params![f32_blob(a), f32_blob(b)],
            |row| row.get(0),
        )
        .unwrap()
    }

    #[test]
    fn test_register_functions() {
        let conn = setup();

        // Identical vectors have similarity 1.
        let sim = cosine(&conn, &[1.0, 0.0, 0.0], &[1.0, 0.0, 0.0]);
        assert!((sim - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_edge_cases() {
        let conn = setup();

        // Opposite directions have similarity -1.
        let sim = cosine(&conn, &[1.0, 2.0, 3.0], &[-1.0, -2.0, -3.0]);
        assert!((sim + 1.0).abs() < 0.001);

        // Orthogonal vectors have similarity 0.
        assert!(cosine(&conn, &[1.0, 0.0], &[0.0, 1.0]).abs() < 0.001);

        // Mismatched dimensions fall back to 0.0.
        assert!(cosine(&conn, &[1.0, 0.0, 0.0], &[1.0, 0.0]).abs() < 0.001);

        // A zero-magnitude vector falls back to 0.0.
        assert!(cosine(&conn, &[0.0, 0.0, 0.0], &[1.0, 2.0, 3.0]).abs() < 0.001);
    }
}