Skip to main content

sqlite_knowledge_graph/
extension.rs

1//! SQLite extension entry point using sqlite-loadable
2//!
3//! This module provides the SQLite loadable extension interface.
4
5use sqlite_loadable::{
6    define_scalar_function, ext::sqlite3ext_result_text, prelude::*, Error, FunctionFlags,
7};
8use std::ffi::CString;
9
10/// Helper function to return text result
11fn result_text(context: *mut sqlite3_context, text: &str) {
12    let cstr = CString::new(text).unwrap();
13    unsafe {
14        sqlite3ext_result_text(
15            context,
16            cstr.as_ptr(),
17            cstr.as_bytes().len() as i32,
18            Some(std::mem::transmute::<
19                i64,
20                unsafe extern "C" fn(*mut std::ffi::c_void),
21            >(-1i64)),
22        );
23    }
24}
25
26/// kg_version() - Returns the extension version
27pub fn kg_version(
28    context: *mut sqlite3_context,
29    _values: &[*mut sqlite3_value],
30) -> Result<(), Error> {
31    result_text(context, env!("CARGO_PKG_VERSION"));
32    Ok(())
33}
34
35/// kg_stats() - Returns graph statistics as JSON
36pub fn kg_stats(
37    context: *mut sqlite3_context,
38    _values: &[*mut sqlite3_value],
39) -> Result<(), Error> {
40    // For now, return a simple message indicating the extension is loaded
41    // Full implementation would require accessing the database connection
42    result_text(
43        context,
44        "{\"status\": \"Extension loaded - use KnowledgeGraph API for full stats\"}",
45    );
46    Ok(())
47}
48
49/// kg_pagerank() - Compute PageRank scores for all entities
50/// Parameters: damping (REAL, default 0.85), max_iterations (INTEGER, default 100), tolerance (REAL, default 1e-6)
51/// Returns JSON with algorithm info
52pub fn kg_pagerank(
53    context: *mut sqlite3_context,
54    values: &[*mut sqlite3_value],
55) -> Result<(), Error> {
56    // Parse optional damping parameter (default 0.85)
57    let damping = if !values.is_empty() {
58        unsafe { sqlite_loadable::ext::sqlite3ext_value_double(values[0]) }
59    } else {
60        0.85
61    };
62
63    // Parse optional max_iterations parameter (default 100)
64    let max_iterations = if values.len() >= 2 {
65        unsafe { sqlite_loadable::ext::sqlite3ext_value_int(values[1]) as usize }
66    } else {
67        100
68    };
69
70    // Parse optional tolerance parameter (default 1e-6)
71    let tolerance = if values.len() >= 3 {
72        unsafe { sqlite_loadable::ext::sqlite3ext_value_double(values[2]) }
73    } else {
74        1e-6
75    };
76
77    // Return configuration info - actual computation requires database access
78    let result = format!(
79        "{{\"algorithm\": \"pagerank\", \"damping\": {}, \"max_iterations\": {}, \"tolerance\": {}, \"note\": \"Use KnowledgeGraph::kg_pagerank() for full computation\"}}",
80        damping, max_iterations, tolerance
81    );
82    result_text(context, &result);
83    Ok(())
84}
85
86/// kg_louvain() - Detect communities using Louvain algorithm
87/// Returns JSON with community memberships and modularity score
88pub fn kg_louvain(
89    context: *mut sqlite3_context,
90    _values: &[*mut sqlite3_value],
91) -> Result<(), Error> {
92    result_text(context, "{\"algorithm\": \"louvain\", \"note\": \"Use KnowledgeGraph::kg_louvain() for full computation\"}");
93    Ok(())
94}
95
96/// kg_bfs() - BFS traversal from a starting entity
97/// Parameters: start_id (INTEGER), max_depth (INTEGER, default 3)
98/// Returns JSON array of {entity_id, depth} objects
99pub fn kg_bfs(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<(), Error> {
100    if values.is_empty() {
101        return Err(Error::new_message(
102            "kg_bfs requires at least 1 argument: start_id",
103        ));
104    }
105
106    let start_id = unsafe { sqlite_loadable::ext::sqlite3ext_value_int64(values[0]) };
107    let max_depth = if values.len() >= 2 {
108        unsafe { sqlite_loadable::ext::sqlite3ext_value_int(values[1]) as u32 }
109    } else {
110        3
111    };
112
113    let result = format!(
114        "{{\"algorithm\": \"bfs\", \"start_id\": {}, \"max_depth\": {}, \"note\": \"Use KnowledgeGraph::kg_bfs_traversal() for full computation\"}}",
115        start_id, max_depth
116    );
117    result_text(context, &result);
118    Ok(())
119}
120
121/// kg_shortest_path() - Find shortest path between two entities
122/// Parameters: from_id (INTEGER), to_id (INTEGER), max_depth (INTEGER, default 10)
123/// Returns JSON array of entity IDs representing the path
124pub fn kg_shortest_path(
125    context: *mut sqlite3_context,
126    values: &[*mut sqlite3_value],
127) -> Result<(), Error> {
128    if values.len() < 2 {
129        return Err(Error::new_message(
130            "kg_shortest_path requires at least 2 arguments: from_id, to_id",
131        ));
132    }
133
134    let from_id = unsafe { sqlite_loadable::ext::sqlite3ext_value_int64(values[0]) };
135    let to_id = unsafe { sqlite_loadable::ext::sqlite3ext_value_int64(values[1]) };
136    let max_depth = if values.len() >= 3 {
137        unsafe { sqlite_loadable::ext::sqlite3ext_value_int(values[2]) as u32 }
138    } else {
139        10
140    };
141
142    let result = format!(
143        "{{\"algorithm\": \"shortest_path\", \"from_id\": {}, \"to_id\": {}, \"max_depth\": {}, \"note\": \"Use KnowledgeGraph::kg_shortest_path() for full computation\"}}",
144        from_id, to_id, max_depth
145    );
146    result_text(context, &result);
147    Ok(())
148}
149
150/// kg_connected_components() - Find connected components in the graph
151/// Returns JSON with component information
152pub fn kg_connected_components(
153    context: *mut sqlite3_context,
154    _values: &[*mut sqlite3_value],
155) -> Result<(), Error> {
156    result_text(context, "{\"algorithm\": \"connected_components\", \"note\": \"Use KnowledgeGraph::kg_connected_components() for full computation\"}");
157    Ok(())
158}
159
160/// Register functions
161fn register_extension_functions(db: *mut sqlite3) -> Result<(), Error> {
162    let flags = FunctionFlags::UTF8 | FunctionFlags::DETERMINISTIC;
163
164    // Basic info functions
165    define_scalar_function(db, "kg_version", 0, kg_version, flags)?;
166    define_scalar_function(db, "kg_stats", 0, kg_stats, flags)?;
167
168    // Graph algorithm functions with optional parameters
169    define_scalar_function(db, "kg_pagerank", 0, kg_pagerank, flags)?;
170    define_scalar_function(db, "kg_pagerank", 1, kg_pagerank, flags)?;
171    define_scalar_function(db, "kg_pagerank", 2, kg_pagerank, flags)?;
172    define_scalar_function(db, "kg_pagerank", 3, kg_pagerank, flags)?;
173
174    define_scalar_function(db, "kg_louvain", 0, kg_louvain, flags)?;
175
176    define_scalar_function(db, "kg_bfs", 1, kg_bfs, flags)?;
177    define_scalar_function(db, "kg_bfs", 2, kg_bfs, flags)?;
178
179    define_scalar_function(db, "kg_shortest_path", 2, kg_shortest_path, flags)?;
180    define_scalar_function(db, "kg_shortest_path", 3, kg_shortest_path, flags)?;
181
182    define_scalar_function(
183        db,
184        "kg_connected_components",
185        0,
186        kg_connected_components,
187        flags,
188    )?;
189
190    Ok(())
191}
192
193/// Extension entry point
194#[sqlite_entrypoint]
195pub fn sqlite3_sqlite_knowledge_graph_init(db: *mut sqlite3) -> Result<(), Error> {
196    register_extension_functions(db)
197}
198
#[cfg(test)]
mod tests {
    /// The version reported by `kg_version()` comes from `CARGO_PKG_VERSION`;
    /// check that its core part really has the `x.y.z` numeric shape instead
    /// of merely containing a dot.
    #[test]
    fn test_kg_version_format() {
        let version = env!("CARGO_PKG_VERSION");
        assert!(!version.is_empty());

        // Drop any pre-release (`-...`) or build-metadata (`+...`) suffix.
        let core = version
            .split(|c: char| c == '-' || c == '+')
            .next()
            .unwrap_or(version);

        let parts: Vec<&str> = core.split('.').collect();
        assert_eq!(parts.len(), 3, "expected x.y.z, got {:?}", version);
        for part in parts {
            assert!(
                part.parse::<u64>().is_ok(),
                "non-numeric version component {:?} in {:?}",
                part,
                version
            );
        }
    }
}