use pgrx::prelude::*;
use pgrx::JsonB;
use serde::{Deserialize, Serialize};
use super::optimizer::OptimizationTarget;
use super::{QueryTrajectory, LEARNING_MANAGER};
/// Per-table learning configuration, deserialized from the optional JSONB
/// argument of `ruvector_enable_learning`.
///
/// Every field carries a serde default, so a partial or empty JSON object is
/// accepted.
#[derive(Debug, Serialize, Deserialize)]
pub struct LearningConfig {
/// Value forwarded to `LEARNING_MANAGER.enable_for_table`; falls back to
/// `default_max_trajectories()` when the key is absent.
#[serde(default = "default_max_trajectories")]
pub max_trajectories: usize,
/// Falls back to `default_num_clusters()` when the key is absent.
/// NOTE(review): no function visible in this file reads this field; confirm
/// whether it is consumed elsewhere.
#[serde(default = "default_num_clusters")]
pub num_clusters: usize,
/// Falls back to `u64::default()` (0) when the key is absent.
/// NOTE(review): unit and consumer are not visible in this file; confirm.
#[serde(default)]
pub auto_tune_interval: u64,
}
/// Serde fallback for `LearningConfig::max_trajectories` when the key is absent.
fn default_max_trajectories() -> usize {
    const FALLBACK_MAX_TRAJECTORIES: usize = 1000;
    FALLBACK_MAX_TRAJECTORIES
}
/// Serde fallback for `LearningConfig::num_clusters` when the key is absent.
fn default_num_clusters() -> usize {
    const FALLBACK_NUM_CLUSTERS: usize = 10;
    FALLBACK_NUM_CLUSTERS
}
impl Default for LearningConfig {
    /// Builds the configuration used when no JSON config is supplied.
    ///
    /// The literals repeat the values returned by `default_max_trajectories`
    /// and `default_num_clusters`; keep both places in sync when changing them.
    fn default() -> Self {
        let max_trajectories = 1000;
        let num_clusters = 10;
        let auto_tune_interval = 0;
        Self {
            max_trajectories,
            num_clusters,
            auto_tune_interval,
        }
    }
}
#[pg_extern]
fn ruvector_enable_learning(
table_name: &str,
config: Option<JsonB>,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let config: LearningConfig = match config {
Some(jsonb) => serde_json::from_value(jsonb.0.clone())?,
None => LearningConfig::default(),
};
LEARNING_MANAGER.enable_for_table(table_name, config.max_trajectories);
Ok(format!(
"Learning enabled for table '{}' with max_trajectories={}",
table_name, config.max_trajectories
))
}
#[pg_extern]
fn ruvector_record_feedback(
table_name: &str,
query_vector: Vec<f32>,
relevant_ids: Vec<i64>,
irrelevant_ids: Vec<i64>,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let tracker = LEARNING_MANAGER
.get_tracker(table_name)
.ok_or_else(|| format!("Learning not enabled for table: {}", table_name))?;
let mut recent = tracker.get_recent(10);
if let Some(traj) = recent.iter_mut().find(|t| t.query_vector == query_vector) {
traj.add_feedback(
relevant_ids.iter().map(|&id| id as u64).collect(),
irrelevant_ids.iter().map(|&id| id as u64).collect(),
);
tracker.record(traj.clone());
Ok(format!(
"Feedback recorded: {} relevant, {} irrelevant",
relevant_ids.len(),
irrelevant_ids.len()
))
} else {
Err("No recent trajectory found matching query vector".into())
}
}
/// Reports trajectory and pattern statistics for `table_name` as JSONB.
///
/// The result has two objects: `trajectories` (filled from `tracker.stats()`)
/// and `patterns` (filled from `bank.stats()`).
///
/// # Errors
/// Fails when no tracker or no reasoning bank is registered for the table.
#[pg_extern]
fn ruvector_learning_stats(
    table_name: &str,
) -> Result<JsonB, Box<dyn std::error::Error + Send + Sync>> {
    let tracker = match LEARNING_MANAGER.get_tracker(table_name) {
        Some(found) => found,
        None => return Err(format!("Learning not enabled for table: {}", table_name).into()),
    };
    let bank = match LEARNING_MANAGER.get_reasoning_bank(table_name) {
        Some(found) => found,
        None => {
            return Err(format!("ReasoningBank not found for table: {}", table_name).into());
        }
    };

    let from_tracker = tracker.stats();
    let from_bank = bank.stats();

    let trajectories = serde_json::json!({
        "total": from_tracker.total_trajectories,
        "with_feedback": from_tracker.trajectories_with_feedback,
        "avg_latency_us": from_tracker.avg_latency_us,
        "avg_precision": from_tracker.avg_precision,
        "avg_recall": from_tracker.avg_recall,
    });
    let patterns = serde_json::json!({
        "total": from_bank.total_patterns,
        "total_samples": from_bank.total_samples,
        "avg_confidence": from_bank.avg_confidence,
        "total_usage": from_bank.total_usage,
    });

    Ok(JsonB(serde_json::json!({
        "trajectories": trajectories,
        "patterns": patterns,
    })))
}
/// Extracts patterns for `table_name` and returns parameter recommendations
/// for the optional sample queries.
///
/// * `optimize_for` - `"speed"` or `"accuracy"`; any other value (including
///   the default `"balanced"`) selects `OptimizationTarget::Balanced`.
/// * `sample_queries` - optional JSON array of numeric arrays. Entries that are
///   not arrays are skipped. A top-level value that is not an array yields no
///   recommendations.
///
/// Pattern extraction is run with a fixed cluster count of 10.
///
/// The JSONB result holds `patterns_extracted`, the `optimize_for` string as
/// passed in, and one `recommendations` entry per decoded sample query.
///
/// # Errors
/// * learning is not enabled for `table_name`;
/// * a sample query array contains a non-numeric element;
/// * pattern extraction fails.
#[pg_extern]
fn ruvector_auto_tune(
    table_name: &str,
    optimize_for: default!(&str, "'balanced'"),
    sample_queries: Option<JsonB>,
) -> Result<JsonB, Box<dyn std::error::Error + Send + Sync>> {
    let optimizer = LEARNING_MANAGER
        .get_optimizer(table_name)
        .ok_or_else(|| format!("Learning not enabled for table: {}", table_name))?;

    let target = match optimize_for {
        "speed" => OptimizationTarget::Speed,
        "accuracy" => OptimizationTarget::Accuracy,
        _ => OptimizationTarget::Balanced,
    };

    // Decode the sample queries before extracting patterns so that malformed
    // input is rejected before any learning state is touched.
    let mut queries: Vec<Vec<f32>> = Vec::new();
    if let Some(entries) = sample_queries
        .as_ref()
        .and_then(|jsonb| jsonb.0.as_array())
    {
        for (index, entry) in entries.iter().enumerate() {
            if let Some(components) = entry.as_array() {
                // Silently dropping a non-numeric component would shorten the
                // vector and tune against a query of the wrong dimensionality.
                let query = components
                    .iter()
                    .map(|component| component.as_f64().map(|value| value as f32))
                    .collect::<Option<Vec<f32>>>()
                    .ok_or_else(|| {
                        format!("sample_queries[{}] contains a non-numeric element", index)
                    })?;
                queries.push(query);
            }
        }
    }

    let patterns_extracted = LEARNING_MANAGER.extract_patterns(table_name, 10)?;

    let mut recommendations = Vec::with_capacity(queries.len());
    for query in &queries {
        let params = optimizer.optimize_with_target(query, target);
        recommendations.push(serde_json::json!({
            "ef_search": params.ef_search,
            "probes": params.probes,
            "confidence": params.confidence,
        }));
    }

    let result = serde_json::json!({
        "patterns_extracted": patterns_extracted,
        "optimize_for": optimize_for,
        "recommendations": recommendations,
    });
    Ok(JsonB(result))
}
/// Asks the reasoning bank of `table_name` to consolidate its patterns using
/// `similarity_threshold` (SQL default 0.9) and reports the count it returns.
///
/// # Errors
/// Fails when no reasoning bank is registered for the table.
#[pg_extern]
fn ruvector_consolidate_patterns(
    table_name: &str,
    similarity_threshold: default!(f64, 0.9),
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
    let reasoning_bank = match LEARNING_MANAGER.get_reasoning_bank(table_name) {
        Some(found) => found,
        None => return Err(format!("Learning not enabled for table: {}", table_name).into()),
    };

    let merged_count = reasoning_bank.consolidate(similarity_threshold);
    let summary = format!(
        "Consolidated {} similar patterns with threshold {}",
        merged_count, similarity_threshold
    );
    Ok(summary)
}
#[pg_extern]
fn ruvector_prune_patterns(
table_name: &str,
min_usage: default!(i32, 5),
min_confidence: default!(f64, 0.5),
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let bank = LEARNING_MANAGER
.get_reasoning_bank(table_name)
.ok_or_else(|| format!("Learning not enabled for table: {}", table_name))?;
let pruned = bank.prune(min_usage as usize, min_confidence);
Ok(format!(
"Pruned {} patterns with min_usage={}, min_confidence={}",
pruned, min_usage, min_confidence
))
}
/// Returns the parameters produced by the optimizer for `query_vector`.
///
/// The JSONB result has the keys `ef_search`, `probes` and `confidence`, taken
/// from the value returned by `optimizer.optimize`.
///
/// # Errors
/// Fails when no optimizer is registered for the table.
#[pg_extern]
fn ruvector_get_search_params(
    table_name: &str,
    query_vector: Vec<f32>,
) -> Result<JsonB, Box<dyn std::error::Error + Send + Sync>> {
    let optimizer = match LEARNING_MANAGER.get_optimizer(table_name) {
        Some(found) => found,
        None => return Err(format!("Learning not enabled for table: {}", table_name).into()),
    };

    let tuned = optimizer.optimize(&query_vector);
    Ok(JsonB(serde_json::json!({
        "ef_search": tuned.ef_search,
        "probes": tuned.probes,
        "confidence": tuned.confidence,
    })))
}
#[pg_extern]
fn ruvector_extract_patterns(
table_name: &str,
num_clusters: default!(i32, 10),
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let patterns_extracted =
LEARNING_MANAGER.extract_patterns(table_name, num_clusters as usize)?;
Ok(format!(
"Extracted {} patterns from trajectories using {} clusters",
patterns_extracted, num_clusters
))
}
#[pg_extern]
fn ruvector_record_trajectory(
table_name: &str,
query_vector: Vec<f32>,
result_ids: Vec<i64>,
latency_us: i64,
ef_search: i32,
probes: i32,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let tracker = LEARNING_MANAGER
.get_tracker(table_name)
.ok_or_else(|| format!("Learning not enabled for table: {}", table_name))?;
let trajectory = QueryTrajectory::new(
query_vector,
result_ids.iter().map(|&id| id as u64).collect(),
latency_us as u64,
ef_search as usize,
probes as usize,
);
tracker.record(trajectory);
Ok(format!(
"Trajectory recorded for {} results",
result_ids.len()
))
}
/// Clears the reasoning bank registered for `table_name`.
///
/// NOTE(review): only `bank.clear()` is called here although the message says
/// "all learning data"; confirm whether recorded trajectories should be
/// dropped as well.
///
/// # Errors
/// Fails when no reasoning bank is registered for the table.
#[pg_extern]
fn ruvector_clear_learning(
    table_name: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
    match LEARNING_MANAGER.get_reasoning_bank(table_name) {
        Some(reasoning_bank) => {
            reasoning_bank.clear();
            Ok(format!(
                "Cleared all learning data for table '{}'",
                table_name
            ))
        }
        None => Err(format!("Learning not enabled for table: {}", table_name).into()),
    }
}
#[cfg(feature = "pg_test")]
#[pg_schema]
mod tests {
    use super::*;

    /// Enabling learning without a config succeeds.
    #[pg_test]
    fn test_enable_learning() {
        let result = ruvector_enable_learning("test_table", None);
        assert!(result.is_ok());
    }

    /// Stats can be read right after enabling, before anything is recorded.
    #[pg_test]
    fn test_learning_stats_empty() {
        ruvector_enable_learning("test_stats", None).unwrap();
        let stats = ruvector_learning_stats("test_stats");
        assert!(stats.is_ok());
    }

    /// A single trajectory can be recorded on an enabled table.
    #[pg_test]
    fn test_record_trajectory() {
        ruvector_enable_learning("test_trajectory", None).unwrap();
        let result = ruvector_record_trajectory(
            "test_trajectory",
            vec![1.0, 2.0, 3.0],
            vec![1, 2, 3],
            1000,
            50,
            10,
        );
        assert!(result.is_ok());
    }

    /// Pattern extraction succeeds after recording 20 trajectories.
    #[pg_test]
    fn test_extract_patterns() {
        ruvector_enable_learning("test_patterns", None).unwrap();
        for i in 0..20 {
            ruvector_record_trajectory(
                "test_patterns",
                vec![i as f32, (i * 2) as f32],
                vec![i, i + 1],
                1000 + i * 100,
                50,
                10,
            )
            .unwrap();
        }
        let result = ruvector_extract_patterns("test_patterns", 5);
        assert!(result.is_ok());
    }

    /// Auto-tuning with the balanced target and no sample queries succeeds.
    #[pg_test]
    fn test_auto_tune() {
        ruvector_enable_learning("test_autotune", None).unwrap();
        for i in 0..10 {
            ruvector_record_trajectory(
                "test_autotune",
                vec![i as f32, (i * 2) as f32],
                vec![i],
                1000,
                50,
                10,
            )
            .unwrap();
        }
        let result = ruvector_auto_tune("test_autotune", "balanced", None);
        assert!(result.is_ok());
    }

    /// Search parameters can be requested after patterns were extracted.
    #[pg_test]
    fn test_get_search_params() {
        ruvector_enable_learning("test_search_params", None).unwrap();
        for i in 0..20 {
            ruvector_record_trajectory(
                "test_search_params",
                vec![i as f32, 0.0],
                vec![i],
                1000,
                50,
                10,
            )
            .unwrap();
        }
        ruvector_extract_patterns("test_search_params", 3).unwrap();
        let result = ruvector_get_search_params("test_search_params", vec![5.0, 0.0]);
        assert!(result.is_ok());
    }

    /// Consolidation succeeds after patterns were extracted.
    #[pg_test]
    fn test_consolidate_patterns() {
        ruvector_enable_learning("test_consolidate", None).unwrap();
        for i in 0..30 {
            ruvector_record_trajectory(
                "test_consolidate",
                vec![i as f32 / 10.0, 0.0],
                vec![i],
                1000,
                50,
                10,
            )
            .unwrap();
        }
        ruvector_extract_patterns("test_consolidate", 10).unwrap();
        let result = ruvector_consolidate_patterns("test_consolidate", 0.95);
        assert!(result.is_ok());
    }

    /// Pruning succeeds after patterns were extracted.
    #[pg_test]
    fn test_prune_patterns() {
        ruvector_enable_learning("test_prune", None).unwrap();
        for i in 0..20 {
            ruvector_record_trajectory("test_prune", vec![i as f32, 0.0], vec![i], 1000, 50, 10)
                .unwrap();
        }
        ruvector_extract_patterns("test_prune", 5).unwrap();
        let result = ruvector_prune_patterns("test_prune", 100, 0.9);
        assert!(result.is_ok());
    }

    /// Clearing leaves the pattern count at zero.
    #[pg_test]
    fn test_clear_learning() {
        ruvector_enable_learning("test_clear", None).unwrap();
        // Record trajectories and run extraction before clearing. Without this
        // the bank never had patterns and the zero-count assertion below would
        // pass even if clearing did nothing.
        for i in 0..20 {
            ruvector_record_trajectory("test_clear", vec![i as f32, 0.0], vec![i], 1000, 50, 10)
                .unwrap();
        }
        ruvector_extract_patterns("test_clear", 5).unwrap();

        let result = ruvector_clear_learning("test_clear");
        assert!(result.is_ok());

        let stats = ruvector_learning_stats("test_clear").unwrap();
        let stats_obj = stats.0.as_object().unwrap();
        let patterns = stats_obj.get("patterns").unwrap().as_object().unwrap();
        assert_eq!(patterns.get("total").unwrap().as_u64().unwrap(), 0);
    }
}