use sqlite_loadable::{
define_scalar_function, ext::sqlite3ext_result_text, prelude::*, Error, FunctionFlags,
};
use std::ffi::CString;
fn result_text(context: *mut sqlite3_context, text: &str) {
let cstr = CString::new(text).unwrap();
unsafe {
sqlite3ext_result_text(
context,
cstr.as_ptr(),
cstr.as_bytes().len() as i32,
Some(std::mem::transmute::<
i64,
unsafe extern "C" fn(*mut std::ffi::c_void),
>(-1i64)),
);
}
}
pub fn kg_version(
context: *mut sqlite3_context,
_values: &[*mut sqlite3_value],
) -> Result<(), Error> {
result_text(context, env!("CARGO_PKG_VERSION"));
Ok(())
}
pub fn kg_stats(
context: *mut sqlite3_context,
_values: &[*mut sqlite3_value],
) -> Result<(), Error> {
result_text(
context,
"{\"status\": \"Extension loaded - use KnowledgeGraph API for full stats\"}",
);
Ok(())
}
pub fn kg_pagerank(
context: *mut sqlite3_context,
values: &[*mut sqlite3_value],
) -> Result<(), Error> {
let damping = if !values.is_empty() {
unsafe { sqlite_loadable::ext::sqlite3ext_value_double(values[0]) }
} else {
0.85
};
let max_iterations = if values.len() >= 2 {
unsafe { sqlite_loadable::ext::sqlite3ext_value_int(values[1]) as usize }
} else {
100
};
let tolerance = if values.len() >= 3 {
unsafe { sqlite_loadable::ext::sqlite3ext_value_double(values[2]) }
} else {
1e-6
};
let result = format!(
"{{\"algorithm\": \"pagerank\", \"damping\": {}, \"max_iterations\": {}, \"tolerance\": {}, \"note\": \"Use KnowledgeGraph::kg_pagerank() for full computation\"}}",
damping, max_iterations, tolerance
);
result_text(context, &result);
Ok(())
}
pub fn kg_louvain(
context: *mut sqlite3_context,
_values: &[*mut sqlite3_value],
) -> Result<(), Error> {
result_text(context, "{\"algorithm\": \"louvain\", \"note\": \"Use KnowledgeGraph::kg_louvain() for full computation\"}");
Ok(())
}
pub fn kg_bfs(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<(), Error> {
if values.is_empty() {
return Err(Error::new_message(
"kg_bfs requires at least 1 argument: start_id",
));
}
let start_id = unsafe { sqlite_loadable::ext::sqlite3ext_value_int64(values[0]) };
let max_depth = if values.len() >= 2 {
unsafe { sqlite_loadable::ext::sqlite3ext_value_int(values[1]) as u32 }
} else {
3
};
let result = format!(
"{{\"algorithm\": \"bfs\", \"start_id\": {}, \"max_depth\": {}, \"note\": \"Use KnowledgeGraph::kg_bfs_traversal() for full computation\"}}",
start_id, max_depth
);
result_text(context, &result);
Ok(())
}
pub fn kg_shortest_path(
context: *mut sqlite3_context,
values: &[*mut sqlite3_value],
) -> Result<(), Error> {
if values.len() < 2 {
return Err(Error::new_message(
"kg_shortest_path requires at least 2 arguments: from_id, to_id",
));
}
let from_id = unsafe { sqlite_loadable::ext::sqlite3ext_value_int64(values[0]) };
let to_id = unsafe { sqlite_loadable::ext::sqlite3ext_value_int64(values[1]) };
let max_depth = if values.len() >= 3 {
unsafe { sqlite_loadable::ext::sqlite3ext_value_int(values[2]) as u32 }
} else {
10
};
let result = format!(
"{{\"algorithm\": \"shortest_path\", \"from_id\": {}, \"to_id\": {}, \"max_depth\": {}, \"note\": \"Use KnowledgeGraph::kg_shortest_path() for full computation\"}}",
from_id, to_id, max_depth
);
result_text(context, &result);
Ok(())
}
pub fn kg_connected_components(
context: *mut sqlite3_context,
_values: &[*mut sqlite3_value],
) -> Result<(), Error> {
result_text(context, "{\"algorithm\": \"connected_components\", \"note\": \"Use KnowledgeGraph::kg_connected_components() for full computation\"}");
Ok(())
}
fn register_extension_functions(db: *mut sqlite3) -> Result<(), Error> {
let flags = FunctionFlags::UTF8 | FunctionFlags::DETERMINISTIC;
define_scalar_function(db, "kg_version", 0, kg_version, flags)?;
define_scalar_function(db, "kg_stats", 0, kg_stats, flags)?;
define_scalar_function(db, "kg_pagerank", 0, kg_pagerank, flags)?;
define_scalar_function(db, "kg_pagerank", 1, kg_pagerank, flags)?;
define_scalar_function(db, "kg_pagerank", 2, kg_pagerank, flags)?;
define_scalar_function(db, "kg_pagerank", 3, kg_pagerank, flags)?;
define_scalar_function(db, "kg_louvain", 0, kg_louvain, flags)?;
define_scalar_function(db, "kg_bfs", 1, kg_bfs, flags)?;
define_scalar_function(db, "kg_bfs", 2, kg_bfs, flags)?;
define_scalar_function(db, "kg_shortest_path", 2, kg_shortest_path, flags)?;
define_scalar_function(db, "kg_shortest_path", 3, kg_shortest_path, flags)?;
define_scalar_function(
db,
"kg_connected_components",
0,
kg_connected_components,
flags,
)?;
Ok(())
}
#[sqlite_entrypoint]
pub fn sqlite3_sqlite_knowledge_graph_init(db: *mut sqlite3) -> Result<(), Error> {
register_extension_functions(db)
}
#[cfg(test)]
mod tests {
#[test]
fn test_kg_version_format() {
let version = env!("CARGO_PKG_VERSION");
assert!(!version.is_empty());
assert!(version.contains('.'));
}
}