use std::path::PathBuf;
use std::thread::JoinHandle;
use tokio::sync::mpsc;
use super::db;
use super::event::StoreEvent;
use super::path;
use super::writer;
/// Bound on the number of events buffered between `Store::record` and the
/// writer thread. Once this many events are queued, further `record` calls
/// drop their event instead of waiting.
const CHANNEL_CAPACITY: usize = 10 * 1_000;
/// Configuration consumed by `Store::open`.
#[derive(Debug, Clone)]
pub struct StoreConfig {
    /// Location of the database file. Its parent directory is created on open
    /// if it does not already exist.
    pub db_path: PathBuf,
    /// Version string handed to the schema-migration step when the store is opened.
    pub mcpr_version: String,
}
/// Handle to the event store.
///
/// Events passed to `Store::record` travel over a bounded channel to a
/// dedicated writer thread, which owns the database connection opened in
/// `Store::open`.
// NOTE: field order determines drop order; keep `tx` before `writer_handle`.
pub struct Store {
    /// Sending half of the channel whose receiver lives on the writer thread.
    /// Replaced by a sender to an already-closed channel during `shutdown`.
    tx: mpsc::Sender<StoreEvent>,
    /// Join handle of the writer thread; `None` once `shutdown` has joined it.
    writer_handle: Option<JoinHandle<()>>,
    /// Path of the database file this store was opened with.
    db_path: PathBuf,
}
impl Store {
    /// Opens (or creates) the database at `config.db_path` and starts the
    /// background writer thread.
    ///
    /// The parent directory is created if needed and schema migrations are run
    /// before the writer thread is spawned. A returned `Store` is therefore
    /// ready to accept events.
    ///
    /// # Errors
    ///
    /// - [`StoreError::Io`] if the parent directory cannot be created or the
    ///   writer thread cannot be spawned.
    /// - [`StoreError::Sqlite`] if the database cannot be opened or migrated.
    pub fn open(config: StoreConfig) -> Result<Self, StoreError> {
        path::ensure_parent_dir(&config.db_path)
            .map_err(|e| StoreError::Io(format!("failed to create db directory: {e}")))?;

        let conn = db::open_connection(&config.db_path)
            .map_err(|e| StoreError::Sqlite(format!("failed to open database: {e}")))?;
        db::run_migrations(&conn, &config.mcpr_version)
            .map_err(|e| StoreError::Sqlite(format!("schema migration failed: {e}")))?;

        let (tx, rx) = mpsc::channel::<StoreEvent>(CHANNEL_CAPACITY);
        // The connection and receiver move into the writer thread, which is
        // then their sole owner.
        let writer_handle = std::thread::Builder::new()
            .name("mcpr-store-writer".into())
            .spawn(move || {
                writer::run_writer_loop(conn, rx);
            })
            .map_err(|e| StoreError::Io(format!("failed to spawn writer thread: {e}")))?;

        Ok(Store {
            tx,
            writer_handle: Some(writer_handle),
            db_path: config.db_path,
        })
    }

    /// Queues `event` for the writer thread without blocking.
    ///
    /// This is best-effort. If the channel is full (`CHANNEL_CAPACITY` events
    /// pending) or the store has been shut down, the event is silently dropped.
    pub fn record(&self, event: StoreEvent) {
        // Intentionally ignore `TrySendError`: losing an event is preferred over
        // blocking the caller.
        let _ = self.tx.try_send(event);
    }

    /// Path of the database file this store was opened with.
    pub fn db_path(&self) -> &PathBuf {
        &self.db_path
    }

    /// Stops the writer thread and waits for it to exit.
    ///
    /// The live sender is swapped for a sender to an already-closed channel and
    /// then dropped. This closes the writer's channel, since this store holds the
    /// only sender. Later `record` calls are discarded.
    ///
    /// Calling this more than once is harmless. A panic in the writer thread is
    /// reported on stderr rather than propagated.
    pub fn shutdown(&mut self) {
        // tokio's `Sender` has no explicit close, so replace it with one whose
        // receiver is already gone and drop the original.
        let (dead_tx, _) = mpsc::channel(1);
        let old_tx = std::mem::replace(&mut self.tx, dead_tx);
        drop(old_tx);

        if let Some(handle) = self.writer_handle.take()
            && let Err(payload) = handle.join()
        {
            // `{:?}` on a panic payload only prints `Any { .. }`; extract the
            // actual message so the failure is diagnosable.
            eprintln!(
                "mcpr-store: writer thread panicked: {}",
                Self::panic_message(&*payload)
            );
        }
    }

    /// Extracts a human-readable message from a thread panic payload.
    ///
    /// Payloads produced by `panic!` are a `&'static str` or a `String`. Any
    /// other payload type yields a fixed placeholder.
    fn panic_message(payload: &(dyn std::any::Any + Send)) -> &str {
        if let Some(&msg) = payload.downcast_ref::<&'static str>() {
            msg
        } else if let Some(msg) = payload.downcast_ref::<String>() {
            msg.as_str()
        } else {
            "<non-string panic payload>"
        }
    }
}
impl Drop for Store {
    /// Performs `Store::shutdown` on the way out, unless the writer thread has
    /// already been joined by an explicit `shutdown` call.
    fn drop(&mut self) {
        let writer_still_running = self.writer_handle.is_some();
        if !writer_still_running {
            return;
        }
        self.shutdown();
    }
}
/// Errors produced by the store, e.g. from `Store::open`.
///
/// Each variant carries a pre-formatted description of the underlying failure.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StoreError {
    /// Filesystem or OS-level failure, such as creating the database directory
    /// or spawning the writer thread.
    Io(String),
    /// Failure opening or migrating the database.
    Sqlite(String),
}

impl std::fmt::Display for StoreError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            StoreError::Io(msg) => write!(f, "store I/O error: {msg}"),
            StoreError::Sqlite(msg) => write!(f, "store SQLite error: {msg}"),
        }
    }
}

impl std::error::Error for StoreError {}
#[cfg(test)]
// Test names follow a `subject__behavior` convention. The double underscore
// would otherwise trip the `non_snake_case` lint.
#[allow(non_snake_case)]
mod tests {
    use super::*;
    use crate::store::event::{RequestEvent, RequestStatus, SessionEvent};

    /// Opens a fresh store backed by the database file at `db_path`.
    fn open_store(db_path: &std::path::Path) -> Store {
        Store::open(StoreConfig {
            db_path: db_path.to_path_buf(),
            mcpr_version: "test".into(),
        })
        .unwrap()
    }

    /// A session event with fixed, known field values.
    fn sample_session() -> StoreEvent {
        StoreEvent::Session(SessionEvent {
            session_id: "s1".into(),
            proxy: "test-proxy".into(),
            started_at: 1000,
            client_name: Some("test-client".into()),
            client_version: Some("0.1".into()),
            client_platform: Some("unknown".into()),
        })
    }

    /// A successful `tools/call` request event belonging to session `s1`.
    fn sample_request(request_id: &'static str) -> StoreEvent {
        StoreEvent::Request(RequestEvent {
            request_id: request_id.into(),
            ts: 1001,
            proxy: "test-proxy".into(),
            session_id: Some("s1".into()),
            method: "tools/call".into(),
            tool: Some("test_tool".into()),
            resource_uri: None,
            prompt_name: None,
            latency_us: 50_000,
            status: RequestStatus::Ok,
            error_code: None,
            error_msg: None,
            bytes_in: Some(100),
            bytes_out: Some(200),
        })
    }

    /// Counts the rows in `table` using a fresh connection to `db_path`.
    fn count_rows(db_path: &std::path::Path, table: &str) -> i64 {
        let conn = db::open_connection(db_path).unwrap();
        conn.query_row(&format!("SELECT COUNT(*) FROM {table}"), [], |row| {
            row.get(0)
        })
        .unwrap()
    }

    #[test]
    fn store__open_record_shutdown() {
        let dir = tempfile::tempdir().unwrap();
        let db_path = dir.path().join("test.db");
        let mut store = open_store(&db_path);

        store.record(sample_session());
        store.record(sample_request("r1"));
        store.shutdown();

        assert_eq!(count_rows(&db_path, "requests"), 1);
        assert_eq!(count_rows(&db_path, "sessions"), 1);

        let conn = db::open_connection(&db_path).unwrap();
        let tool: String = conn
            .query_row(
                "SELECT tool FROM requests WHERE request_id = 'r1'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert_eq!(tool, "test_tool");
    }

    #[test]
    fn store__drop_shuts_down_writer() {
        let dir = tempfile::tempdir().unwrap();
        let db_path = dir.path().join("test.db");
        let store = open_store(&db_path);

        store.record(sample_session());
        store.record(sample_request("r1"));
        // No explicit shutdown: `Drop` must perform it.
        drop(store);

        assert_eq!(count_rows(&db_path, "requests"), 1);
        assert_eq!(count_rows(&db_path, "sessions"), 1);
    }

    #[test]
    fn store__shutdown_is_idempotent_and_discards_later_records() {
        let dir = tempfile::tempdir().unwrap();
        let db_path = dir.path().join("test.db");
        let mut store = open_store(&db_path);

        store.shutdown();
        // Recording after shutdown must not panic; the event is dropped.
        store.record(sample_request("late"));
        store.shutdown();
        drop(store);

        assert_eq!(count_rows(&db_path, "requests"), 0);
    }
}