use pyo3::exceptions::{PyOSError, PyRuntimeError};
use pyo3::prelude::*;
use pyo3::types::PyDict;
use std::collections::HashMap;
use std::path::PathBuf;
use crate::{Graph, StorageManager, Value};
/// Convert a database [`Value`] into the equivalent Python object.
///
/// `Null` becomes `None`; booleans, integers, floats and strings become the
/// matching Python scalars; `List` and `Map` are converted recursively into a
/// Python `list` and `dict`.
fn value_to_py(py: Python<'_>, v: &Value) -> PyObject {
    match v {
        Value::Null => py.None(),
        Value::Bool(flag) => flag.into_py(py),
        Value::Int(num) => num.into_py(py),
        Value::Float(num) => num.into_py(py),
        Value::String(text) => text.into_py(py),
        Value::List(elements) => {
            let mut converted: Vec<PyObject> = Vec::with_capacity(elements.len());
            for element in elements {
                converted.push(value_to_py(py, element));
            }
            converted.into_py(py)
        }
        Value::Map(entries) => {
            let dict = PyDict::new_bound(py);
            for (key, inner) in entries {
                let py_inner = value_to_py(py, inner);
                dict.set_item(key, py_inner).unwrap();
            }
            dict.into_py(py)
        }
    }
}
fn py_to_value(py: Python<'_>, obj: &PyObject) -> Value {
use pyo3::types::{PyBool, PyFloat, PyInt, PyList, PyString};
let bound = obj.bind(py);
if bound.is_none() {
return Value::Null;
}
if let Ok(b) = bound.downcast::<PyBool>() {
return Value::Bool(b.is_true());
}
if let Ok(i) = bound.downcast::<PyInt>() {
if let Ok(n) = i.extract::<i64>() {
return Value::Int(n);
}
}
if let Ok(f) = bound.downcast::<PyFloat>() {
return Value::Float(f.value());
}
if let Ok(s) = bound.downcast::<PyString>() {
return Value::String(s.extract::<String>().unwrap_or_default());
}
if let Ok(lst) = bound.downcast::<PyList>() {
let items: Vec<Value> = lst.iter()
.map(|x| py_to_value(py, &x.into_py(py)))
.collect();
return Value::List(items);
}
if let Ok(d) = bound.downcast::<PyDict>() {
let mut map = std::collections::HashMap::new();
for (k, v) in d.iter() {
let key = k.str().and_then(|s| s.extract::<String>()).unwrap_or_default();
map.insert(key, py_to_value(py, &v.into_py(py)));
}
return Value::Map(map);
}
Value::String(obj.bind(py).str().and_then(|s| s.extract::<String>()).unwrap_or_default())
}
/// Translate a database error into a Python exception.
///
/// `DbError::Storage` (the I/O case) becomes `OSError` carrying the inner
/// error's message. Every other variant becomes `RuntimeError` with the
/// error's `Display` text.
fn db_err_to_py(e: crate::types::DbError) -> PyErr {
    match e {
        crate::types::DbError::Storage(io_err) => PyOSError::new_err(io_err.to_string()),
        other => PyRuntimeError::new_err(other.to_string()),
    }
}
/// Pivot row-oriented results into a column-oriented Python `dict`.
///
/// Each key is a column name and each value is a `list` with one entry per
/// row. A row without that column contributes `None`. Column order is the
/// order in which names are first seen while scanning the rows. Within a row
/// it follows the `HashMap` iteration order, which is unspecified.
fn rows_to_columnar(
    py: Python<'_>,
    rows: &[HashMap<String, Value>],
) -> PyResult<PyObject> {
    let mut seen = std::collections::HashSet::new();
    let columns: Vec<&str> = rows
        .iter()
        .flat_map(|row| row.keys().map(String::as_str))
        .filter(|&name| seen.insert(name))
        .collect();

    let table = PyDict::new_bound(py);
    for name in columns {
        let cells = rows
            .iter()
            .map(|row| match row.get(name) {
                Some(v) => value_to_py(py, v),
                None => py.None(),
            })
            .collect::<Vec<PyObject>>();
        table.set_item(name, cells)?;
    }
    Ok(table.into_py(py))
}
#[pyclass]
pub struct MiniGdb {
graph: Option<Graph>,
storage: Option<StorageManager>,
next_txn_id: u64,
}
#[pymethods]
impl MiniGdb {
fn query(&mut self, py: Python<'_>, gql: &str) -> PyResult<Vec<PyObject>> {
let graph = self.graph.as_mut()
.ok_or_else(|| PyRuntimeError::new_err("database is closed"))?;
let (rows, _ops) = crate::query_capturing(gql, graph, &mut self.next_txn_id)
.map_err(db_err_to_py)?;
let py_rows = rows
.into_iter()
.map(|row| {
let d = PyDict::new_bound(py);
for (k, v) in &row {
d.set_item(k, value_to_py(py, v)).unwrap();
}
d.into_py(py)
})
.collect();
Ok(py_rows)
}
fn query_with_params(
&mut self,
py: Python<'_>,
gql: &str,
params: &pyo3::Bound<'_, PyDict>,
) -> PyResult<Vec<PyObject>> {
let graph = self.graph.as_mut()
.ok_or_else(|| PyRuntimeError::new_err("database is closed"))?;
let rust_params: HashMap<String, Value> = params
.iter()
.filter_map(|(k, v)| {
let key = k.str().ok()?.extract::<String>().ok()?;
Some((key, py_to_value(py, &v.into_py(py))))
})
.collect();
let (rows, _ops) = crate::query_capturing_with_params(gql, graph, &mut self.next_txn_id, rust_params)
.map_err(db_err_to_py)?;
let py_rows = rows
.into_iter()
.map(|row| {
let d = PyDict::new_bound(py);
for (k, v) in &row {
d.set_item(k, value_to_py(py, v)).unwrap();
}
d.into_py(py)
})
.collect();
Ok(py_rows)
}
fn query_df(&mut self, py: Python<'_>, gql: &str) -> PyResult<PyObject> {
let graph = self.graph.as_mut()
.ok_or_else(|| PyRuntimeError::new_err("database is closed"))?;
let (rows, _ops) =
crate::query_capturing(gql, graph, &mut self.next_txn_id)
.map_err(db_err_to_py)?;
let cols = rows_to_columnar(py, &rows)?;
let pl = py.import_bound("polars").map_err(|_| {
PyRuntimeError::new_err("polars is not installed. Run: pip install polars")
})?;
let df = pl.getattr("DataFrame")?.call1((cols,))?;
Ok(df.into_py(py))
}
fn query_pandas(&mut self, py: Python<'_>, gql: &str) -> PyResult<PyObject> {
let graph = self.graph.as_mut()
.ok_or_else(|| PyRuntimeError::new_err("database is closed"))?;
let (rows, _ops) =
crate::query_capturing(gql, graph, &mut self.next_txn_id)
.map_err(db_err_to_py)?;
let cols = rows_to_columnar(py, &rows)?;
let pd = py.import_bound("pandas").map_err(|_| {
PyRuntimeError::new_err("pandas is not installed. Run: pip install pandas")
})?;
let df = pd.getattr("DataFrame")?.call1((cols,))?;
Ok(df.into_py(py))
}
fn begin(&mut self) -> PyResult<()> {
self.graph.as_mut()
.ok_or_else(|| PyRuntimeError::new_err("database is closed"))?
.begin_transaction().map_err(db_err_to_py)
}
fn commit(&mut self) -> PyResult<()> {
self.graph.as_mut()
.ok_or_else(|| PyRuntimeError::new_err("database is closed"))?
.commit_transaction().map_err(db_err_to_py)
}
fn rollback(&mut self) -> PyResult<()> {
self.graph.as_mut()
.ok_or_else(|| PyRuntimeError::new_err("database is closed"))?
.rollback_transaction().map_err(db_err_to_py)
}
fn clear(&mut self) -> PyResult<()> {
self.graph.as_mut()
.ok_or_else(|| PyRuntimeError::new_err("database is closed"))?
.clear()
.map_err(db_err_to_py)
}
#[pyo3(signature = (path, label=None))]
fn load_csv_nodes(
&mut self,
path: &str,
label: Option<&str>,
) -> PyResult<HashMap<String, String>> {
use std::fs::File;
use crate::csv_import::load_nodes_csv;
let graph = self.graph.as_mut()
.ok_or_else(|| pyo3::exceptions::PyRuntimeError::new_err("database is closed"))?;
let file = File::open(path)
.map_err(|e| pyo3::exceptions::PyOSError::new_err(e.to_string()))?;
let result = load_nodes_csv(file, graph, label).map_err(db_err_to_py)?;
graph.csv_id_map = result.id_map.clone();
let out = crate::csv_import::id_map_to_strings(&result.id_map);
Ok(out)
}
#[pyo3(signature = (path, id_map=None, label=None))]
fn load_csv_edges(
&mut self,
py: Python<'_>,
path: &str,
id_map: Option<HashMap<String, String>>,
label: Option<&str>,
) -> PyResult<PyObject> {
use std::fs::File;
use crate::csv_import::{load_edges_csv, id_map_from_strings};
use pyo3::types::PyDict;
let graph = self.graph.as_mut()
.ok_or_else(|| pyo3::exceptions::PyRuntimeError::new_err("database is closed"))?;
let file = File::open(path)
.map_err(|e| pyo3::exceptions::PyOSError::new_err(e.to_string()))?;
let resolved = if let Some(m) = id_map.as_ref() {
id_map_from_strings(m)
} else {
graph.csv_id_map.clone()
};
let result = load_edges_csv(file, graph, &resolved, label).map_err(db_err_to_py)?;
let d = PyDict::new_bound(py);
d.set_item("inserted", result.inserted)?;
d.set_item("skipped", result.skipped)?;
Ok(d.into_py(py))
}
#[pyo3(signature = (nodes_path, edges_path, node_label=None, edge_label=None))]
fn load_csv(
&mut self,
py: Python<'_>,
nodes_path: &str,
edges_path: &str,
node_label: Option<&str>,
edge_label: Option<&str>,
) -> PyResult<PyObject> {
use std::fs::File;
use crate::csv_import::{load_nodes_csv, load_edges_csv};
use pyo3::types::PyDict;
let graph = self.graph.as_mut()
.ok_or_else(|| pyo3::exceptions::PyRuntimeError::new_err("database is closed"))?;
let nf = File::open(nodes_path)
.map_err(|e| pyo3::exceptions::PyOSError::new_err(e.to_string()))?;
let nr = load_nodes_csv(nf, graph, node_label).map_err(db_err_to_py)?;
let id_map = nr.id_map.clone();
graph.csv_id_map = id_map.clone();
let ef = File::open(edges_path)
.map_err(|e| pyo3::exceptions::PyOSError::new_err(e.to_string()))?;
let er = load_edges_csv(ef, graph, &id_map, edge_label).map_err(db_err_to_py)?;
let d = PyDict::new_bound(py);
d.set_item("nodes_inserted", nr.inserted)?;
d.set_item("edges_inserted", er.inserted)?;
d.set_item("skipped", er.skipped)?;
Ok(d.into_py(py))
}
fn close(&mut self) -> PyResult<()> {
if let Some(mut graph) = self.graph.take() {
if graph.is_in_transaction() {
let _ = graph.rollback_transaction();
}
if let Some(storage) = self.storage.take() {
storage.checkpoint(&graph).map_err(db_err_to_py)?;
}
}
Ok(())
}
fn __enter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
slf
}
fn __exit__(
&mut self,
_exc_type: PyObject,
_exc_val: PyObject,
_exc_tb: PyObject,
) -> PyResult<bool> {
self.close()?;
Ok(false) }
}
/// Resolve the storage directory for the named graph:
/// `<platform data dir>/minigdb/graphs/<name>`.
///
/// Validation (violations raise `RuntimeError`):
/// * The length must be 1–64, measured in bytes (`str::len`), so a
///   non-ASCII character counts as more than one.
/// * Every character must satisfy `char::is_alphanumeric` (which accepts
///   non-ASCII letters and digits) or be `_` or `-`. Path separators and `.`
///   are therefore rejected, so `name` is always a single path component.
///
/// Raises `OSError` if the platform data directory cannot be determined.
fn graph_dir(name: &str) -> PyResult<PathBuf> {
    if !(1..=64).contains(&name.len()) {
        return Err(PyRuntimeError::new_err("Graph name must be 1–64 characters"));
    }

    let permitted = |ch: char| ch.is_alphanumeric() || matches!(ch, '_' | '-');
    if name.chars().any(|ch| !permitted(ch)) {
        return Err(PyRuntimeError::new_err(
            "Graph name must contain only alphanumeric characters, '_', or '-'",
        ));
    }

    let data_root = match dirs::data_dir() {
        Some(dir) => dir,
        None => return Err(PyOSError::new_err("Cannot determine platform data directory")),
    };
    let graphs_root = data_root.join("minigdb").join("graphs");
    Ok(graphs_root.join(name))
}
/// Open the named graph in the platform data directory, creating it if needed.
///
/// Naming rules are enforced by `graph_dir`. A directory that cannot be
/// created raises the matching `OSError` subclass (for example
/// `PermissionError`). Storage errors are mapped by `db_err_to_py`.
#[pyfunction]
fn open(name: &str) -> PyResult<MiniGdb> {
    let dir = graph_dir(name)?;
    // `?` uses PyO3's `From<io::Error> for PyErr`, which selects an OSError
    // subclass from the error kind.
    std::fs::create_dir_all(&dir)?;
    let (storage, graph) = StorageManager::open(&dir).map_err(db_err_to_py)?;
    Ok(MiniGdb { graph: Some(graph), storage: Some(storage), next_txn_id: 0 })
}
/// Open (creating if needed) a graph stored in the directory `path`.
///
/// Unlike `open`, no name validation is applied: `path` is used as given. A
/// directory that cannot be created raises the matching `OSError` subclass.
/// Storage errors are mapped by `db_err_to_py`.
#[pyfunction]
fn open_at(path: &str) -> PyResult<MiniGdb> {
    let dir = std::path::Path::new(path);
    // `?` uses PyO3's `From<io::Error> for PyErr`, which selects an OSError
    // subclass from the error kind.
    std::fs::create_dir_all(dir)?;
    let (storage, graph) = StorageManager::open(dir).map_err(db_err_to_py)?;
    Ok(MiniGdb { graph: Some(graph), storage: Some(storage), next_txn_id: 0 })
}
/// Convert a JSON value received from the server into a Python object.
///
/// Integers representable as `i64` or `u64` become an exact Python `int`.
/// Any other number becomes a `float` (NaN if it has no `f64` form). Arrays
/// and objects are converted recursively into `list` and `dict`.
fn json_to_py(py: Python<'_>, v: &serde_json::Value) -> PyObject {
    match v {
        serde_json::Value::Null => py.None(),
        serde_json::Value::Bool(b) => b.into_py(py),
        serde_json::Value::Number(n) => {
            if let Some(i) = n.as_i64() {
                i.into_py(py)
            } else if let Some(u) = n.as_u64() {
                // Values above i64::MAX: keep them exact rather than rounding through f64.
                u.into_py(py)
            } else {
                n.as_f64().unwrap_or(f64::NAN).into_py(py)
            }
        }
        serde_json::Value::String(s) => s.into_py(py),
        serde_json::Value::Array(a) => a
            .iter()
            .map(|x| json_to_py(py, x))
            .collect::<Vec<_>>()
            .into_py(py),
        serde_json::Value::Object(o) => {
            let d = PyDict::new_bound(py);
            for (k, v) in o {
                d.set_item(k, json_to_py(py, v)).unwrap();
            }
            d.into_py(py)
        }
    }
}
/// Pivot JSON result rows into a column-oriented Python `dict`.
///
/// Each key is a column name and each value is a `list` with one entry per
/// row. A row that lacks the column contributes `None`. Columns appear in
/// first-seen order while scanning rows. Within a row this follows the
/// `HashMap` iteration order, which is unspecified.
fn json_rows_to_columnar(
    py: Python<'_>,
    rows: &[std::collections::HashMap<String, serde_json::Value>],
) -> PyResult<PyObject> {
    let mut seen: std::collections::HashSet<&str> = std::collections::HashSet::new();
    let mut columns: Vec<&str> = Vec::new();
    for name in rows.iter().flat_map(|row| row.keys()) {
        if seen.insert(name) {
            columns.push(name);
        }
    }

    let table = PyDict::new_bound(py);
    for name in &columns {
        let mut cells: Vec<PyObject> = Vec::with_capacity(rows.len());
        for row in rows {
            cells.push(row.get(*name).map_or_else(|| py.None(), |v| json_to_py(py, v)));
        }
        table.set_item(*name, cells)?;
    }
    Ok(table.into_py(py))
}
/// Convert JSON result rows into a list of Python dicts, one dict per row.
///
/// Errors from `dict.__setitem__` are propagated instead of panicking.
fn json_rows_to_py_dicts(
    py: Python<'_>,
    rows: &[std::collections::HashMap<String, serde_json::Value>],
) -> PyResult<Vec<PyObject>> {
    rows.iter()
        .map(|row| -> PyResult<PyObject> {
            let d = PyDict::new_bound(py);
            for (k, v) in row {
                d.set_item(k, json_to_py(py, v))?;
            }
            Ok(d.into_py(py))
        })
        .collect()
}

/// Python-facing client for a remote minigdb server.
///
/// Requests and responses are exchanged as one JSON document per line over a
/// single TCP connection. `reader` and `writer` are two handles to that same
/// socket.
#[pyclass]
pub struct MiniGdbClient {
    reader: std::io::BufReader<std::net::TcpStream>,
    writer: std::net::TcpStream,
    next_id: u64,
}

impl MiniGdbClient {
    /// Send one query and block until its response line arrives.
    ///
    /// The request is the single line `{"id": <n>, "query": <gql>}`. The
    /// response is expected to be a single JSON line with either a string
    /// `"error"` or a `"rows"` array of objects. Non-object rows are dropped,
    /// and a missing `"rows"` yields no rows.
    ///
    /// Raises `OSError` on socket I/O failure. Raises `RuntimeError` if the
    /// server closes the connection, sends invalid JSON, or reports an error.
    ///
    /// NOTE(review): requests and responses are matched purely by order. The
    /// response `id` is not checked against the request's.
    fn send_query(
        &mut self,
        gql: &str,
    ) -> PyResult<Vec<std::collections::HashMap<String, serde_json::Value>>> {
        use std::io::{BufRead, Write};
        let id = self.next_id;
        self.next_id += 1;
        // `Value`'s `Display` emits compact JSON (newlines inside strings are
        // escaped), so the request stays on one line.
        let mut line = serde_json::json!({"id": id, "query": gql}).to_string();
        line.push('\n');
        self.writer
            .write_all(line.as_bytes())
            .map_err(|e| PyOSError::new_err(e.to_string()))?;

        let mut resp_line = String::new();
        let bytes_read = self
            .reader
            .read_line(&mut resp_line)
            .map_err(|e| PyOSError::new_err(e.to_string()))?;
        if bytes_read == 0 {
            // EOF: previously surfaced as a confusing JSON parse error.
            return Err(PyRuntimeError::new_err("server closed the connection"));
        }
        let mut resp: serde_json::Value = serde_json::from_str(resp_line.trim())
            .map_err(|e| PyRuntimeError::new_err(format!("server sent invalid JSON: {e}")))?;
        if let Some(err) = resp.get("error").and_then(|e| e.as_str()) {
            return Err(PyRuntimeError::new_err(err.to_string()));
        }

        // Move the rows out of the response instead of deep-cloning every value.
        let rows: Vec<std::collections::HashMap<String, serde_json::Value>> =
            match resp.get_mut("rows").map(serde_json::Value::take) {
                Some(serde_json::Value::Array(arr)) => arr
                    .into_iter()
                    .filter_map(|row| match row {
                        serde_json::Value::Object(obj) => {
                            Some(obj.into_iter().collect::<std::collections::HashMap<_, _>>())
                        }
                        _ => None,
                    })
                    .collect(),
                _ => Vec::new(),
            };
        Ok(rows)
    }
}

#[pymethods]
impl MiniGdbClient {
    /// Run a GQL query on the server and return the rows as a list of dicts.
    fn query(&mut self, py: Python<'_>, gql: &str) -> PyResult<Vec<PyObject>> {
        let rows = self.send_query(gql)?;
        json_rows_to_py_dicts(py, &rows)
    }

    /// Run a query on the server and return the result as a `polars.DataFrame`.
    ///
    /// `polars` is imported before the query is sent, so a missing dependency
    /// is reported without the server executing the query.
    fn query_df(&mut self, py: Python<'_>, gql: &str) -> PyResult<PyObject> {
        let pl = py.import_bound("polars").map_err(|_| {
            PyRuntimeError::new_err("polars is not installed. Run: pip install polars")
        })?;
        let rows = self.send_query(gql)?;
        let cols = json_rows_to_columnar(py, &rows)?;
        let df = pl.getattr("DataFrame")?.call1((cols,))?;
        Ok(df.into_py(py))
    }

    /// Run a query on the server and return the result as a `pandas.DataFrame`.
    ///
    /// `pandas` is imported before the query is sent, so a missing dependency
    /// is reported without the server executing the query.
    fn query_pandas(&mut self, py: Python<'_>, gql: &str) -> PyResult<PyObject> {
        let pd = py.import_bound("pandas").map_err(|_| {
            PyRuntimeError::new_err("pandas is not installed. Run: pip install pandas")
        })?;
        let rows = self.send_query(gql)?;
        let cols = json_rows_to_columnar(py, &rows)?;
        let df = pd.getattr("DataFrame")?.call1((cols,))?;
        Ok(df.into_py(py))
    }

    /// Send `BEGIN` to the server.
    fn begin(&mut self) -> PyResult<()> {
        self.send_query("BEGIN").map(|_| ())
    }

    /// Send `COMMIT` to the server.
    fn commit(&mut self) -> PyResult<()> {
        self.send_query("COMMIT").map(|_| ())
    }

    /// Send `ROLLBACK` to the server.
    fn rollback(&mut self) -> PyResult<()> {
        self.send_query("ROLLBACK").map(|_| ())
    }

    /// Shut down both directions of the connection.
    ///
    /// OS errors are raised as `OSError`.
    fn close(&mut self) -> PyResult<()> {
        self.writer
            .shutdown(std::net::Shutdown::Both)
            .map_err(|e| PyOSError::new_err(e.to_string()))
    }

    /// Context-manager entry; returns the client itself.
    fn __enter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
        slf
    }

    /// Context-manager exit.
    ///
    /// Closes the connection on a best-effort basis (errors are ignored) and
    /// returns `False` so an exception from the `with` body propagates.
    fn __exit__(
        &mut self,
        _exc_type: PyObject,
        _exc_val: PyObject,
        _exc_tb: PyObject,
    ) -> PyResult<bool> {
        let _ = self.close();
        Ok(false)
    }
}
/// Connect to a minigdb server at `addr` (any string accepted by
/// `TcpStream::connect`, e.g. `"host:port"`).
///
/// Reads and discards the server's one-line greeting. Raises `OSError` if the
/// connection fails, the greeting cannot be read, or the server closes the
/// connection before sending it.
#[pyfunction]
fn connect(addr: &str) -> PyResult<MiniGdbClient> {
    use std::io::BufRead;
    let stream = std::net::TcpStream::connect(addr)
        .map_err(|e| PyOSError::new_err(format!("cannot connect to {addr}: {e}")))?;
    // A second handle to the same socket, so reading and writing can use
    // independent objects.
    let writer = stream
        .try_clone()
        .map_err(|e| PyOSError::new_err(e.to_string()))?;
    let mut reader = std::io::BufReader::new(stream);
    let mut hello = String::new();
    let bytes_read = reader
        .read_line(&mut hello)
        .map_err(|e| PyOSError::new_err(format!("reading hello from {addr}: {e}")))?;
    if bytes_read == 0 {
        // EOF before any greeting: the server hung up immediately. Fail now
        // rather than on the first query.
        return Err(PyOSError::new_err(format!(
            "{addr} closed the connection before sending a greeting"
        )));
    }
    Ok(MiniGdbClient { reader, writer, next_id: 0 })
}
/// Python module initializer.
///
/// Registers the `MiniGdb` and `MiniGdbClient` classes and the `open`,
/// `open_at` and `connect` functions.
#[pymodule]
fn minigdb(module: &Bound<'_, PyModule>) -> PyResult<()> {
    module.add_class::<MiniGdb>()?;
    module.add_class::<MiniGdbClient>()?;

    let open_fn = wrap_pyfunction!(open, module)?;
    module.add_function(open_fn)?;
    let open_at_fn = wrap_pyfunction!(open_at, module)?;
    module.add_function(open_at_fn)?;
    let connect_fn = wrap_pyfunction!(connect, module)?;
    module.add_function(connect_fn)?;

    Ok(())
}