use std::collections::HashMap;
use std::time::Instant;
use serde_json::Value;
/// One column of a cached table, as read from `information_schema.columns`.
#[derive(Debug, Clone)]
pub struct CachedColumn {
    /// Column name.
    pub name: String,
    /// Declared SQL data type (e.g. `integer`, `text`, `USER-DEFINED`).
    pub data_type: String,
    /// Whether the column accepts NULL.
    pub is_nullable: bool,
    /// Default expression, if any.
    pub column_default: Option<String>,
}

/// pgvector-specific metadata for a `vector`-typed column.
#[derive(Debug, Clone)]
pub struct VectorColumnInfo {
    /// Name of the vector column.
    pub column_name: String,
    /// Declared dimensionality, when known.
    pub dimensions: Option<i32>,
    /// Index access method (e.g. `hnsw`, `ivfflat`), when indexed.
    pub index_type: Option<String>,
    /// Distance operator class (e.g. `vector_cosine_ops`), when indexed.
    pub distance_ops: Option<String>,
}

/// A single table together with its column and vector metadata.
#[derive(Debug, Clone)]
pub struct CachedTable {
    pub schema: String,
    pub name: String,
    /// Primary-key column, when one was discovered.
    pub pk_column: Option<String>,
    pub columns: Vec<CachedColumn>,
    pub vector_columns: Vec<VectorColumnInfo>,
}

/// In-memory snapshot of database schema metadata.
#[derive(Debug, Clone)]
pub struct SchemaCache {
    /// Cached tables, keyed by table name (see the `*_for_table` lookups).
    pub tables: HashMap<String, CachedTable>,
}

impl SchemaCache {
    /// An empty cache with no tables.
    pub fn empty() -> Self {
        Self {
            tables: HashMap::new(),
        }
    }

    /// Columns of `table`, or `None` if the table is not cached.
    pub fn columns_for_table(&self, table: &str) -> Option<&Vec<CachedColumn>> {
        self.tables.get(table).map(|t| &t.columns)
    }

    /// Primary-key column name of `table`, if cached and known.
    pub fn pk_for_table(&self, table: &str) -> Option<&str> {
        self.tables
            .get(table)
            .and_then(|t| t.pk_column.as_deref())
    }

    /// Vector-column metadata of `table`, or `None` if the table is not cached.
    pub fn vector_columns_for_table(&self, table: &str) -> Option<&Vec<VectorColumnInfo>> {
        self.tables.get(table).map(|t| &t.vector_columns)
    }

    /// Renders a human-readable summary of every cached table.
    ///
    /// Tables are listed deterministically, ordered by `(schema, name)`.
    /// `vector` columns show their dimension and index details in place of
    /// the generic data type.
    pub fn to_summary(&self) -> String {
        use std::fmt::Write; // write!/writeln! into the output String

        if self.tables.is_empty() {
            return "No tables found in database.".to_string();
        }
        let mut tables: Vec<&CachedTable> = self.tables.values().collect();
        // Fix: sort by (schema, name). Sorting by name alone left ties
        // (same table name in two schemas) in HashMap iteration order,
        // making the summary nondeterministic.
        tables.sort_unstable_by(|a, b| (&a.schema, &a.name).cmp(&(&b.schema, &b.name)));
        let mut out = String::new();
        for (i, table) in tables.iter().enumerate() {
            if i > 0 {
                out.push('\n'); // blank line between tables
            }
            let pk_info = match &table.pk_column {
                Some(pk) => format!(" (pk: {})", pk),
                None => String::new(),
            };
            // Writing into a String cannot fail, so unwrap is safe here.
            writeln!(out, "TABLE: {}.{}{}", table.schema, table.name, pk_info).unwrap();
            for col in &table.columns {
                let nullable = if col.is_nullable { "" } else { "NOT NULL" };
                let default = match &col.column_default {
                    Some(d) => format!(" DEFAULT {}", d),
                    None => String::new(),
                };
                // For pgvector columns, show `vector(<dims>) [<index>, <ops>]`
                // instead of the catalog's generic data type.
                let vector_info = table
                    .vector_columns
                    .iter()
                    .find(|vc| vc.column_name == col.name)
                    .map(|vc| {
                        let dim = vc
                            .dimensions
                            .map(|d| format!("({})", d))
                            .unwrap_or_default();
                        let idx = match (&vc.index_type, &vc.distance_ops) {
                            (Some(it), Some(ops)) => format!(" [{}, {}]", it, ops),
                            (Some(it), None) => format!(" [{}]", it),
                            _ => String::new(),
                        };
                        format!("vector{}{}", dim, idx)
                    });
                let dtype = vector_info.unwrap_or_else(|| col.data_type.clone());
                writeln!(out, " {:<24} {:<18} {}{}", col.name, dtype, nullable, default).unwrap();
            }
        }
        out
    }
}
/// Catalog-introspection SQL used to populate a `SchemaCache`.
///
/// All queries exclude the `pg_catalog` and `information_schema` schemas.
pub mod introspect {
    /// Every column of every user base table, with its type, nullability,
    /// default expression and underlying type name (`udt_name`), ordered by
    /// schema, table, and ordinal position.
    pub const COLUMNS_QUERY: &str = r#"
SELECT
c.table_schema,
c.table_name,
c.column_name,
c.data_type,
c.is_nullable,
c.column_default,
c.udt_name
FROM information_schema.columns c
JOIN information_schema.tables t
ON c.table_schema = t.table_schema
AND c.table_name = t.table_name
WHERE t.table_type = 'BASE TABLE'
AND c.table_schema NOT IN ('pg_catalog', 'information_schema')
ORDER BY c.table_schema, c.table_name, c.ordinal_position
"#;

    /// Primary-key columns per table.
    ///
    /// NOTE(review): a composite primary key produces one row per member
    /// column — the consumer presumably keeps only one (`CachedTable`
    /// stores a single `pk_column`); verify against the loader.
    pub const PK_QUERY: &str = r#"
SELECT
tc.table_schema,
tc.table_name,
kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
WHERE tc.constraint_type = 'PRIMARY KEY'
AND tc.table_schema NOT IN ('pg_catalog', 'information_schema')
ORDER BY tc.table_schema, tc.table_name
"#;

    /// Columns whose underlying type is pgvector's `vector`, with the
    /// dimension read from `pg_attribute.atttypmod`.
    ///
    /// NOTE(review): assumes pgvector stores the declared dimension directly
    /// in `atttypmod` (and -1 when undeclared) — confirm against the pgvector
    /// catalog docs for the deployed version.
    pub const VECTOR_COLUMNS_QUERY: &str = r#"
SELECT
c.table_schema,
c.table_name,
c.column_name,
a.atttypmod as dimensions
FROM information_schema.columns c
JOIN pg_catalog.pg_attribute a
ON a.attname = c.column_name
JOIN pg_catalog.pg_class cl
ON cl.relname = c.table_name
AND a.attrelid = cl.oid
JOIN pg_catalog.pg_namespace n
ON n.nspname = c.table_schema
AND cl.relnamespace = n.oid
WHERE c.udt_name = 'vector'
AND c.table_schema NOT IN ('pg_catalog', 'information_schema')
ORDER BY c.table_schema, c.table_name
"#;

    /// Index definitions that mention "vector".
    ///
    /// NOTE(review): `LIKE '%vector%'` is a crude substring filter — it also
    /// matches any index whose definition merely contains the text "vector"
    /// (e.g. a column named `vector_id`); the consumer must filter further.
    pub const VECTOR_INDEXES_QUERY: &str = r#"
SELECT
schemaname,
tablename,
indexname,
indexdef
FROM pg_indexes
WHERE indexdef LIKE '%vector%'
AND schemaname NOT IN ('pg_catalog', 'information_schema')
"#;
}
/// A stored query result: the primary-key values of matching rows,
/// addressable by an opaque handle.
#[derive(Debug, Clone)]
pub struct CachedResultSet {
    /// Handle of the form `rs_NNNN`, minted by `ResultSetCache::store`.
    pub id: String,
    /// Table the rows came from.
    pub table: String,
    /// Primary-key column the `ids` values belong to.
    pub pk_column: String,
    /// Primary-key values of the matched rows (arbitrary JSON scalars).
    pub ids: Vec<Value>,
    /// Creation time, used for oldest-first eviction.
    pub created_at: Instant,
}
/// Bounded cache of `CachedResultSet`s, keyed by their `rs_NNNN` handle.
pub struct ResultSetCache {
    /// Live result sets, keyed by handle.
    sets: HashMap<String, CachedResultSet>,
    /// Source of handle numbers; incremented on every `store`.
    counter: u64,
    /// Capacity: when reached, the oldest entry is evicted before insert.
    max_sets: usize,
}
impl ResultSetCache {
    /// Creates a cache that retains at most `max_sets` result sets.
    ///
    /// NOTE(review): `max_sets == 0` effectively behaves as capacity 1
    /// (the previous entry is evicted on every `store`) — confirm intent.
    pub fn new(max_sets: usize) -> Self {
        Self {
            sets: HashMap::new(),
            counter: 0,
            max_sets,
        }
    }

    /// Stores a result set and returns its freshly minted handle.
    ///
    /// When the cache is at capacity, the oldest entry is evicted first.
    pub fn store(
        &mut self,
        table: String,
        pk_column: String,
        ids: Vec<Value>,
    ) -> String {
        if self.sets.len() >= self.max_sets {
            self.evict_oldest();
        }
        self.counter += 1;
        // {:04} pads to 4 digits for readability; ids past 9999 grow wider.
        let id = format!("rs_{:04}", self.counter);
        self.sets.insert(
            id.clone(),
            CachedResultSet {
                id: id.clone(),
                table,
                pk_column,
                ids,
                created_at: Instant::now(),
            },
        );
        id
    }

    /// Looks up a result set by handle.
    pub fn get(&self, id: &str) -> Option<&CachedResultSet> {
        self.sets.get(id)
    }

    /// Removes the entry with the oldest `created_at`.
    ///
    /// NOTE(review): entries created within the same clock tick share an
    /// `Instant`, so which one is evicted then depends on HashMap iteration
    /// order — acceptable for a best-effort cache.
    fn evict_oldest(&mut self) {
        if let Some(oldest_key) = self
            .sets
            .iter()
            .min_by_key(|(_, v)| v.created_at)
            .map(|(k, _)| k.clone())
        {
            self.sets.remove(&oldest_key);
        }
    }

    /// Number of currently cached result sets.
    pub fn len(&self) -> usize {
        self.sets.len()
    }

    /// `true` when no result sets are cached.
    pub fn is_empty(&self) -> bool {
        self.sets.is_empty()
    }

    /// Drops all cached result sets.
    ///
    /// Fix: the counter is intentionally NOT reset. Resetting it re-minted
    /// old handles (`rs_0001` again after a clear), so a caller holding a
    /// stale handle could silently resolve to a different, newer result set.
    /// Handles stay unique for the lifetime of the cache.
    pub fn clear(&mut self) {
        self.sets.clear();
    }
}
// Manual Debug: summarize counts instead of deriving, which would dump every
// cached result set (each potentially holding many id values).
impl std::fmt::Debug for ResultSetCache {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ResultSetCache")
            .field("count", &self.sets.len())
            .field("max", &self.max_sets)
            .field("counter", &self.counter)
            .finish()
    }
}