sqlite_knowledge_graph/
functions.rs1use crate::error::Result;
8
9pub fn register_functions(conn: &rusqlite::Connection) -> Result<()> {
11 conn.create_scalar_function(
13 "kg_cosine_similarity",
14 2,
15 rusqlite::functions::FunctionFlags::SQLITE_UTF8,
16 |ctx| {
17 let vec1_blob: Vec<u8> = ctx.get(0)?;
18 let vec2_blob: Vec<u8> = ctx.get(1)?;
19
20 let mut vec1 = Vec::new();
22 for chunk in vec1_blob.chunks_exact(4) {
23 let bytes: [u8; 4] = match chunk.try_into() {
24 Ok(b) => b,
25 Err(_) => return Ok(0.0f64),
26 };
27 vec1.push(f32::from_le_bytes(bytes));
28 }
29
30 let mut vec2 = Vec::new();
31 for chunk in vec2_blob.chunks_exact(4) {
32 let bytes: [u8; 4] = match chunk.try_into() {
33 Ok(b) => b,
34 Err(_) => return Ok(0.0f64),
35 };
36 vec2.push(f32::from_le_bytes(bytes));
37 }
38
39 if vec1.len() != vec2.len() {
40 return Ok(0.0f64);
41 }
42
43 let mut dot_product = 0.0_f32;
44 let mut norm_a = 0.0_f32;
45 let mut norm_b = 0.0_f32;
46
47 for i in 0..vec1.len() {
48 dot_product += vec1[i] * vec2[i];
49 norm_a += vec1[i] * vec1[i];
50 norm_b += vec2[i] * vec2[i];
51 }
52
53 if norm_a == 0.0 || norm_b == 0.0 {
54 return Ok(0.0f64);
55 }
56
57 let similarity = dot_product / (norm_a.sqrt() * norm_b.sqrt());
58 Ok(similarity as f64)
59 },
60 )?;
61
62 conn.create_scalar_function(
64 "kg_bit_count",
65 1,
66 rusqlite::functions::FunctionFlags::SQLITE_UTF8,
67 |ctx| {
68 let val: Option<i64> = ctx.get(0)?;
69 match val {
70 Some(x) => Ok(x.count_ones() as i64),
71 None => Ok(0),
72 }
73 },
74 )?;
75
76 Ok(())
77}
78
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::{params, Connection};

    /// Opens an in-memory database with the schema and the custom functions installed.
    fn setup() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();
        register_functions(&conn).unwrap();
        conn
    }

    /// Packs `values` as consecutive little-endian `f32`s, the blob layout that
    /// `kg_cosine_similarity` reads.
    fn f32_blob(values: &[f32]) -> Vec<u8> {
        values.iter().flat_map(|v| v.to_le_bytes()).collect()
    }

    /// Evaluates `kg_cosine_similarity(a, b)` through SQL.
    fn cosine(conn: &Connection, a: &[u8], b: &[u8]) -> f64 {
        conn.query_row(
            "SELECT kg_cosine_similarity(?1, ?2)",
            params![a, b],
            |row| row.get(0),
        )
        .unwrap()
    }

    /// Evaluates `kg_bit_count(<arg>)`, with `arg` spliced in as a SQL literal.
    fn bit_count(conn: &Connection, arg: &str) -> i64 {
        conn.query_row(&format!("SELECT kg_bit_count({arg})"), [], |r| r.get(0))
            .unwrap()
    }

    #[test]
    fn test_register_functions() {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();
        assert!(register_functions(&conn).is_ok());

        let v = f32_blob(&[1.0, 0.0, 0.0]);
        let sim = cosine(&conn, &v, &v);
        assert!((sim - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_cosine_orthogonal() {
        let conn = setup();
        let sim = cosine(&conn, &f32_blob(&[1.0, 0.0, 0.0]), &f32_blob(&[0.0, 1.0, 0.0]));
        assert!(sim.abs() < 1e-6);
    }

    #[test]
    fn test_cosine_opposite() {
        let conn = setup();
        let sim = cosine(
            &conn,
            &f32_blob(&[1.0, 2.0, 3.0]),
            &f32_blob(&[-1.0, -2.0, -3.0]),
        );
        assert!((sim + 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_cosine_dimension_mismatch_is_zero() {
        let conn = setup();
        let sim = cosine(&conn, &f32_blob(&[1.0, 2.0, 3.0]), &f32_blob(&[1.0, 2.0]));
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_cosine_zero_vector_is_zero() {
        let conn = setup();
        let sim = cosine(&conn, &f32_blob(&[0.0, 0.0, 0.0]), &f32_blob(&[1.0, 2.0, 3.0]));
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_kg_bit_count_positive() {
        let conn = setup();
        assert_eq!(bit_count(&conn, "7"), 3);
    }

    #[test]
    fn test_kg_bit_count_zero() {
        let conn = setup();
        assert_eq!(bit_count(&conn, "0"), 0);
    }

    #[test]
    fn test_kg_bit_count_null() {
        let conn = setup();
        assert_eq!(bit_count(&conn, "NULL"), 0);
    }

    #[test]
    fn test_kg_bit_count_negative_counts_twos_complement_bits() {
        let conn = setup();
        assert_eq!(bit_count(&conn, "-1"), 64);
    }
}