use crate::Result;
/// Extracts the host of a plaintext `http://` URL that does not point at loopback.
///
/// Returns `None` when `url` does not start with `http://` (matched
/// case-insensitively) or when its host is `localhost` or a loopback IP
/// literal. Otherwise returns the host with any userinfo, port and IPv6
/// brackets removed.
///
/// The host is taken from the authority component (everything before the first
/// `/`, `?` or `#`) after dropping an optional `user:pass@` prefix and `:port`
/// suffix. Loopback detection parses the host as an IP address rather than
/// comparing string prefixes, so a DNS name such as `127.evil.example` or a URL
/// such as `http://127.0.0.1@evil.example/` is still reported as remote, while
/// `http://[::1]:8765/` and `http://LOCALHOST/` are recognised as local.
// Deliberately not feature-gated so the parsing can be unit-tested without the
// HTTP features; its only caller is gated, hence the conditional `allow`.
#[cfg_attr(
    not(any(feature = "tabpfn_http", feature = "llm_http")),
    allow(dead_code)
)]
fn remote_plaintext_http_host(url: &str) -> Option<&str> {
    const SCHEME: &str = "http://";

    // `get` (rather than indexing) also rejects inputs that are too short or
    // whose seventh byte is not a character boundary.
    let scheme = url.get(..SCHEME.len())?;
    if !scheme.eq_ignore_ascii_case(SCHEME) {
        return None;
    }
    let rest = &url[SCHEME.len()..];

    // Authority ends at the first path, query or fragment delimiter.
    let authority = rest.split(['/', '?', '#']).next().unwrap_or(rest);
    // Anything before the last '@' is userinfo, not the host.
    let host_port = authority
        .rsplit_once('@')
        .map_or(authority, |(_, host_port)| host_port);
    let host = match host_port.strip_prefix('[') {
        // Bracketed IPv6 literal: the host is the text between the brackets.
        Some(bracketed) => bracketed
            .split_once(']')
            .map_or(host_port, |(inner, _)| inner),
        // Otherwise the first ':' starts the port.
        None => host_port.split_once(':').map_or(host_port, |(name, _)| name),
    };

    let is_loopback = host.eq_ignore_ascii_case("localhost")
        || matches!(host.parse::<std::net::IpAddr>(), Ok(ip) if ip.is_loopback());
    if is_loopback { None } else { Some(host) }
}

/// Logs a warning when `url` would send corrector traffic over plaintext HTTP
/// to a non-loopback host.
///
/// Nothing is logged for non-`http://` URLs, for loopback hosts, or when the
/// environment variable `SAMKHYA_ALLOW_REMOTE_HTTP` is set to exactly `1`.
/// `backend` is only interpolated into the log message.
#[cfg(any(feature = "tabpfn_http", feature = "llm_http"))]
fn warn_if_remote_plaintext_http(url: &str, backend: &'static str) {
    let Some(host) = remote_plaintext_http_host(url) else {
        return;
    };
    if std::env::var("SAMKHYA_ALLOW_REMOTE_HTTP").as_deref() == Ok("1") {
        return;
    }
    log::warn!(
        "samkhya {backend} corrector configured with plaintext HTTP to non-loopback host {host}; \
         features and baseline_estimate will travel unencrypted. Use https:// or set \
         SAMKHYA_ALLOW_REMOTE_HTTP=1 to silence this warning."
    );
}
/// Inputs handed to a [`Corrector`] when it is asked to refine an estimate.
///
/// Optional fields that are `None` are encoded as `0.0` by
/// [`CorrectionFeatures::to_vec`].
#[derive(Debug, Clone, Default)]
pub struct CorrectionFeatures {
    /// Estimated row count that the corrector is asked to refine.
    pub baseline_estimate: u64,
    /// Row count of the left input, when known (presumably of a join — confirm).
    pub left_input_rows: Option<u64>,
    /// Row count of the right input, when known (presumably of a join — confirm).
    pub right_input_rows: Option<u64>,
    /// Distinct-value count on the left side, when known.
    pub left_distinct: Option<u64>,
    /// Distinct-value count on the right side, when known.
    pub right_distinct: Option<u64>,
    /// Number of predicates.
    pub predicate_count: u32,
    /// Join depth.
    pub join_depth: u32,
}

impl CorrectionFeatures {
    /// Number of entries in the vector returned by [`CorrectionFeatures::to_vec`].
    pub const FEATURE_LEN: usize = 7;

    /// Encodes the features as a flat `f64` vector in declaration order:
    /// `[baseline_estimate, left_input_rows, right_input_rows, left_distinct,
    /// right_distinct, predicate_count, join_depth]`.
    ///
    /// Missing optional values become `0.0`. `u64` values above 2^53 lose
    /// precision in the conversion to `f64`.
    pub fn to_vec(&self) -> Vec<f64> {
        let or_zero = |value: Option<u64>| value.unwrap_or(0) as f64;
        // The array type pins the layout to `FEATURE_LEN`: adding or removing
        // an entry without updating the constant is a compile error.
        let encoded: [f64; Self::FEATURE_LEN] = [
            self.baseline_estimate as f64,
            or_zero(self.left_input_rows),
            or_zero(self.right_input_rows),
            or_zero(self.left_distinct),
            or_zero(self.right_distinct),
            f64::from(self.predicate_count),
            f64::from(self.join_depth),
        ];
        encoded.to_vec()
    }
}
/// A strategy that may refine a baseline row-count estimate.
///
/// Implementations must be `Send + Sync`.
pub trait Corrector: Send + Sync {
    /// Attempts to correct the estimate described by `features`.
    ///
    /// `Ok(Some(n))` carries a corrected estimate; `Ok(None)` means the
    /// corrector has no answer (the HTTP-backed correctors in this file return
    /// it on transport or parse failure). Callers presumably fall back to the
    /// baseline in that case — confirm against the call site.
    fn correct(&self, features: &CorrectionFeatures) -> Result<Option<u64>>;
    /// Short, static identifier of this corrector (e.g. `"identity"`).
    fn name(&self) -> &'static str;
}
pub struct IdentityCorrector;
impl Corrector for IdentityCorrector {
fn correct(&self, features: &CorrectionFeatures) -> Result<Option<u64>> {
Ok(Some(features.baseline_estimate))
}
fn name(&self) -> &'static str {
"identity"
}
}
/// Placeholder corrector that never produces a correction.
pub struct TabPfnStub;

impl Corrector for TabPfnStub {
    /// Ignores its input and always returns `Ok(None)`.
    fn correct(&self, _features: &CorrectionFeatures) -> Result<Option<u64>> {
        let no_correction: Option<u64> = None;
        Ok(no_correction)
    }

    /// Returns `"tabpfn-stub"`.
    fn name(&self) -> &'static str {
        "tabpfn-stub"
    }
}
#[cfg(feature = "gbt")]
pub mod gbt {
use gbdt::config::{Config, Loss};
use gbdt::decision_tree::{Data, DataVec};
use gbdt::gradient_boost::GBDT;
use super::{CorrectionFeatures, Corrector};
use crate::feedback::Observation;
use crate::lpbound::saturating_clamp;
use crate::{Error, Result};
/// Training and clamping parameters for `GbtCorrector`.
#[derive(Debug, Clone)]
pub struct GbtOptions {
    /// Shrinkage applied per boosting iteration (narrowed to `f32` for training).
    pub learning_rate: f64,
    /// Maximum depth of each tree.
    pub max_depth: u32,
    /// Number of boosting iterations.
    pub num_trees: u32,
    /// Upper bound applied to corrected estimates.
    pub ceiling: u64,
    /// Minimum leaf size passed to the trainer.
    pub min_leaf_size: usize,
}

impl Default for GbtOptions {
    /// Defaults: 50 trees of depth 4, learning rate 0.1, leaf size 1, and an
    /// effectively unbounded ceiling (`u64::MAX`).
    fn default() -> Self {
        GbtOptions {
            num_trees: 50,
            max_depth: 4,
            learning_rate: 0.1,
            min_leaf_size: 1,
            ceiling: u64::MAX,
        }
    }
}
pub struct GbtCorrector {
model: GBDT,
ceiling: u64,
}
impl GbtCorrector {
pub fn train(observations: &[Observation], options: GbtOptions) -> Result<Self> {
if observations.is_empty() {
return Err(Error::Feedback(
"cannot train GbtCorrector: observation slice is empty".into(),
));
}
let mut training: DataVec = Vec::with_capacity(observations.len());
for obs in observations {
if obs.est_rows == 0 || obs.actual_rows == 0 {
continue;
}
let features = CorrectionFeatures {
baseline_estimate: obs.est_rows,
..Default::default()
};
let feature_f32: Vec<f32> =
features.to_vec().into_iter().map(|v| v as f32).collect();
let ratio_log = (obs.actual_rows as f64 / obs.est_rows as f64).ln() as f32;
training.push(Data::new_training_data(feature_f32, 1.0, ratio_log, None));
}
if training.is_empty() {
return Err(Error::Feedback(
"cannot train GbtCorrector: all observations had zero est or actual rows"
.into(),
));
}
let mut cfg = Config::new();
cfg.set_feature_size(CorrectionFeatures::FEATURE_LEN);
cfg.set_max_depth(options.max_depth);
cfg.set_iterations(options.num_trees as usize);
cfg.set_shrinkage(options.learning_rate as f32);
cfg.set_min_leaf_size(options.min_leaf_size);
cfg.set_loss(&loss_name(Loss::SquaredError));
let mut model = GBDT::new(&cfg);
model.fit(&mut training);
Ok(Self {
model,
ceiling: options.ceiling,
})
}
pub fn predict_log_ratio(&self, features: &CorrectionFeatures) -> f64 {
let feature_f32: Vec<f32> = features.to_vec().into_iter().map(|v| v as f32).collect();
let probe: DataVec = vec![Data::new_test_data(feature_f32, None)];
let preds = self.model.predict(&probe);
preds.first().copied().unwrap_or(0.0) as f64
}
pub fn ceiling(&self) -> u64 {
self.ceiling
}
}
impl Corrector for GbtCorrector {
    /// Multiplies the baseline by `exp(predicted log-ratio)` and clamps the
    /// product to the configured ceiling. Always returns `Ok(Some(_))`.
    fn correct(&self, features: &CorrectionFeatures) -> Result<Option<u64>> {
        let multiplier = self.predict_log_ratio(features).exp();
        let corrected = features.baseline_estimate as f64 * multiplier;
        let clamped = saturating_clamp(corrected, self.ceiling);
        Ok(Some(clamped))
    }

    /// Returns `"gbt"`.
    fn name(&self) -> &'static str {
        "gbt"
    }
}
/// Converts a `Loss` variant into the string form expected by `Config::set_loss`.
fn loss_name(loss: Loss) -> String {
    use gbdt::config::loss2string;
    loss2string(&loss)
}
}
#[cfg(feature = "additive_gbt")]
pub mod additive {
use gbdt::config::{Config, Loss};
use gbdt::decision_tree::{Data, DataVec};
use gbdt::gradient_boost::GBDT;
use std::sync::Mutex;
use super::{CorrectionFeatures, Corrector};
use crate::feedback::Observation;
use crate::lpbound::saturating_clamp;
use crate::{Error, Result};
/// Training and clamping parameters for `AdditiveGbtCorrector`.
#[derive(Debug, Clone)]
pub struct AdditiveGbtOptions {
    /// Shrinkage applied per boosting iteration (narrowed to `f32` for training).
    pub learning_rate: f64,
    /// Maximum depth of each tree.
    pub max_depth: u32,
    /// Number of boosting iterations.
    pub num_trees: u32,
    /// Upper bound applied to corrected estimates.
    pub ceiling: u64,
    /// Minimum leaf size passed to the trainer.
    pub min_leaf_size: usize,
}

impl Default for AdditiveGbtOptions {
    /// Defaults: 50 trees of depth 4, learning rate 0.1, leaf size 1, and an
    /// effectively unbounded ceiling (`u64::MAX`).
    fn default() -> Self {
        AdditiveGbtOptions {
            num_trees: 50,
            max_depth: 4,
            learning_rate: 0.1,
            min_leaf_size: 1,
            ceiling: u64::MAX,
        }
    }
}
/// Gradient-boosted corrector that predicts the actual row count directly
/// instead of a multiplicative ratio.
pub struct AdditiveGbtCorrector {
    model: Mutex<GBDT>,
    ceiling: u64,
}

impl AdditiveGbtCorrector {
    /// Fits a model that maps `est_rows` to `actual_rows`.
    ///
    /// Every observation is used, including those with zero estimated or
    /// actual rows. Only `baseline_estimate` is populated in the training
    /// features; every other feature is left at its default.
    ///
    /// # Errors
    ///
    /// Returns `Error::Feedback` when `observations` is empty.
    pub fn train(observations: &[Observation], options: AdditiveGbtOptions) -> Result<Self> {
        if observations.is_empty() {
            return Err(Error::Feedback(
                "cannot train AdditiveGbtCorrector: observation slice is empty".into(),
            ));
        }
        let mut training: DataVec = Vec::with_capacity(observations.len());
        for obs in observations {
            let features = CorrectionFeatures {
                baseline_estimate: obs.est_rows,
                ..Default::default()
            };
            let feature_f32: Vec<f32> =
                features.to_vec().into_iter().map(|v| v as f32).collect();
            // The regression target is the raw row count, not a log-ratio.
            let target = obs.actual_rows as f32;
            training.push(Data::new_training_data(feature_f32, 1.0, target, None));
        }
        // Guaranteed by the emptiness check above: nothing is filtered out.
        debug_assert!(!training.is_empty());
        let mut cfg = Config::new();
        cfg.set_feature_size(CorrectionFeatures::FEATURE_LEN);
        cfg.set_max_depth(options.max_depth);
        cfg.set_iterations(options.num_trees as usize);
        cfg.set_shrinkage(options.learning_rate as f32);
        cfg.set_min_leaf_size(options.min_leaf_size);
        cfg.set_loss(&gbdt::config::loss2string(&Loss::SquaredError));
        let mut model = GBDT::new(&cfg);
        model.fit(&mut training);
        Ok(Self {
            model: Mutex::new(model),
            ceiling: options.ceiling,
        })
    }

    /// Predicts the row count for `features` as a raw (possibly negative)
    /// `f64`; returns `0.0` if the model yields no prediction.
    pub fn predict_rows(&self, features: &CorrectionFeatures) -> f64 {
        let feature_f32: Vec<f32> = features.to_vec().into_iter().map(|v| v as f32).collect();
        let probe: DataVec = vec![Data::new_test_data(feature_f32, None)];
        // The model is only read through the guard, so a panic in another
        // thread cannot leave it half-updated. Recover the guard from a
        // poisoned lock instead of making every later prediction panic.
        let model = self
            .model
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let preds = model.predict(&probe);
        preds.first().copied().unwrap_or(0.0) as f64
    }

    /// Upper bound applied to corrected estimates.
    pub fn ceiling(&self) -> u64 {
        self.ceiling
    }
}
impl Corrector for AdditiveGbtCorrector {
    /// Floors the predicted row count at zero, then clamps it to the
    /// configured ceiling. Always returns `Ok(Some(_))`.
    fn correct(&self, features: &CorrectionFeatures) -> Result<Option<u64>> {
        let predicted = self.predict_rows(features);
        let non_negative = predicted.max(0.0);
        let clamped = saturating_clamp(non_negative, self.ceiling);
        Ok(Some(clamped))
    }

    /// Returns `"additive_gbt"`.
    fn name(&self) -> &'static str {
        "additive_gbt"
    }
}
}
#[cfg(feature = "tabpfn_http")]
pub mod tabpfn {
use serde::{Deserialize, Serialize};
use std::time::Duration;
use super::{CorrectionFeatures, Corrector};
use crate::Result;
use crate::lpbound::saturating_clamp;
/// Connection settings for `TabPfnHttpCorrector`.
#[derive(Debug, Clone)]
pub struct TabPfnHttpOptions {
    /// Full URL that inference requests are POSTed to.
    pub base_url: String,
    /// Timeout in milliseconds, applied separately to connect, read and write.
    pub timeout_ms: u64,
    /// Upper bound applied to estimates returned by the service.
    pub ceiling: u64,
}

impl Default for TabPfnHttpOptions {
    /// Defaults to `http://localhost:8765/infer` with a 50 ms timeout and an
    /// effectively unbounded ceiling (`u64::MAX`).
    fn default() -> Self {
        let base_url = String::from("http://localhost:8765/infer");
        TabPfnHttpOptions {
            base_url,
            timeout_ms: 50,
            ceiling: u64::MAX,
        }
    }
}
/// JSON body sent to the TabPFN inference endpoint.
#[derive(Serialize)]
struct InferRequest<'a> {
    /// Feature vector as produced by `CorrectionFeatures::to_vec`.
    features: &'a [f64],
    /// Baseline estimate, repeated as its own field alongside `features`.
    baseline_estimate: u64,
}
/// JSON body expected back from the TabPFN inference endpoint.
#[derive(Deserialize)]
struct InferResponse {
    /// Estimate returned by the service, before the ceiling is applied.
    estimate: u64,
}
/// Corrector that asks a TabPFN inference service over HTTP for an estimate.
pub struct TabPfnHttpCorrector {
    options: TabPfnHttpOptions,
}

impl TabPfnHttpCorrector {
    /// Creates a corrector from `options`.
    ///
    /// Passes `base_url` to `warn_if_remote_plaintext_http`, which may log a
    /// warning.
    pub fn new(options: TabPfnHttpOptions) -> Self {
        super::warn_if_remote_plaintext_http(&options.base_url, "tabpfn_http");
        Self { options }
    }

    /// Creates a corrector for `base_url`, using defaults for everything else.
    pub fn with_url(base_url: impl Into<String>) -> Self {
        let mut options = TabPfnHttpOptions::default();
        options.base_url = base_url.into();
        Self::new(options)
    }

    /// The options this corrector was built with.
    pub fn options(&self) -> &TabPfnHttpOptions {
        &self.options
    }

    /// POSTs the features as JSON and returns the service's estimate.
    ///
    /// Any request or parse failure is logged at debug level and reported as
    /// `None`. A fresh agent is built for every call.
    fn try_infer(&self, features: &CorrectionFeatures) -> Option<u64> {
        let endpoint = &self.options.base_url;
        let encoded = features.to_vec();
        let budget = Duration::from_millis(self.options.timeout_ms);
        let agent = ureq::AgentBuilder::new()
            .timeout_connect(budget)
            .timeout_read(budget)
            .timeout_write(budget)
            .build();
        let request = InferRequest {
            features: &encoded,
            baseline_estimate: features.baseline_estimate,
        };

        let response = agent
            .post(endpoint)
            .send_json(&request)
            .map_err(|err| {
                log::debug!("tabpfn_http: request to {} failed: {}", endpoint, err);
            })
            .ok()?;

        response
            .into_json::<InferResponse>()
            .map(|body| body.estimate)
            .map_err(|err| {
                log::debug!(
                    "tabpfn_http: response from {} failed to parse: {}",
                    endpoint,
                    err
                );
            })
            .ok()
    }
}
impl Corrector for TabPfnHttpCorrector {
    /// Returns the service's estimate clamped to the configured ceiling, or
    /// `Ok(None)` when the request failed. Never returns `Err`.
    fn correct(&self, features: &CorrectionFeatures) -> Result<Option<u64>> {
        let ceiling = self.options.ceiling;
        let corrected = self
            .try_infer(features)
            .map(|estimate| saturating_clamp(estimate as f64, ceiling));
        Ok(corrected)
    }

    /// Returns `"tabpfn-http"`.
    fn name(&self) -> &'static str {
        "tabpfn-http"
    }
}
}
#[cfg(feature = "llm_http")]
pub mod llm {
use serde::{Deserialize, Serialize};
use std::time::Duration;
use super::{CorrectionFeatures, Corrector};
use crate::Result;
use crate::lpbound::saturating_clamp;
/// Timeout, in milliseconds, used by `LlmHttpOptions::default`.
pub const DEFAULT_TIMEOUT_MS: u64 = 2_000;
/// Largest timeout, in milliseconds, that `LlmHttpCorrector::new` accepts;
/// larger values are lowered to this.
pub const MAX_TIMEOUT_MS: u64 = 60_000;
/// Endpoint used by `LlmHttpOptions::default`.
pub const DEFAULT_URL: &str = "http://127.0.0.1:8766/infer";
/// Connection settings for `LlmHttpCorrector`.
#[derive(Debug, Clone)]
pub struct LlmHttpOptions {
    /// Full URL that inference requests are POSTed to.
    pub base_url: String,
    /// Timeout in milliseconds, applied separately to connect, read and write.
    pub timeout_ms: u64,
    /// Upper bound applied to estimates returned by the service.
    pub ceiling: u64,
}

impl Default for LlmHttpOptions {
    /// Defaults to `DEFAULT_URL`, `DEFAULT_TIMEOUT_MS` and an effectively
    /// unbounded ceiling (`u64::MAX`).
    fn default() -> Self {
        LlmHttpOptions {
            ceiling: u64::MAX,
            timeout_ms: DEFAULT_TIMEOUT_MS,
            base_url: String::from(DEFAULT_URL),
        }
    }
}
/// JSON body sent to the LLM inference endpoint.
#[derive(Serialize)]
struct InferRequest<'a> {
    /// Feature vector as produced by `CorrectionFeatures::to_vec`.
    features: &'a [f64],
    /// Baseline estimate, repeated as its own field alongside `features`.
    baseline_estimate: u64,
}
/// JSON body expected back from the LLM inference endpoint.
#[derive(Deserialize)]
struct InferResponse {
    /// Estimate returned by the service, before the ceiling is applied.
    estimate: u64,
}
/// Corrector that asks an LLM-backed inference service over HTTP for an estimate.
pub struct LlmHttpCorrector {
    options: LlmHttpOptions,
}

impl LlmHttpCorrector {
    /// Creates a corrector from `options`, lowering `timeout_ms` to
    /// `MAX_TIMEOUT_MS` when it exceeds that bound.
    ///
    /// Passes `base_url` to `warn_if_remote_plaintext_http`, which may log a
    /// warning.
    pub fn new(mut options: LlmHttpOptions) -> Self {
        options.timeout_ms = options.timeout_ms.min(MAX_TIMEOUT_MS);
        super::warn_if_remote_plaintext_http(&options.base_url, "llm_http");
        Self { options }
    }

    /// Creates a corrector for `base_url`, using defaults for everything else.
    pub fn with_url(base_url: impl Into<String>) -> Self {
        let mut options = LlmHttpOptions::default();
        options.base_url = base_url.into();
        Self::new(options)
    }

    /// The options in effect (with the timeout already bounded).
    pub fn options(&self) -> &LlmHttpOptions {
        &self.options
    }

    /// POSTs the features as JSON and returns the service's estimate.
    ///
    /// Any request or parse failure is logged at debug level and reported as
    /// `None`. A fresh agent is built for every call.
    fn try_infer(&self, features: &CorrectionFeatures) -> Option<u64> {
        let endpoint = &self.options.base_url;
        let encoded = features.to_vec();
        let budget = Duration::from_millis(self.options.timeout_ms);
        let agent = ureq::AgentBuilder::new()
            .timeout_connect(budget)
            .timeout_read(budget)
            .timeout_write(budget)
            .build();
        let request = InferRequest {
            features: &encoded,
            baseline_estimate: features.baseline_estimate,
        };

        let response = agent
            .post(endpoint)
            .send_json(&request)
            .map_err(|err| {
                log::debug!("llm_http: request to {} failed: {}", endpoint, err);
            })
            .ok()?;

        response
            .into_json::<InferResponse>()
            .map(|body| body.estimate)
            .map_err(|err| {
                log::debug!(
                    "llm_http: response from {} failed to parse: {}",
                    endpoint,
                    err
                );
            })
            .ok()
    }
}
impl Corrector for LlmHttpCorrector {
    /// Returns the service's estimate clamped to the configured ceiling, or
    /// `Ok(None)` when the request failed. Never returns `Err`.
    fn correct(&self, features: &CorrectionFeatures) -> Result<Option<u64>> {
        let ceiling = self.options.ceiling;
        let corrected = self
            .try_infer(features)
            .map(|estimate| saturating_clamp(estimate as f64, ceiling));
        Ok(corrected)
    }

    /// Returns `"llm-http"`.
    fn name(&self) -> &'static str {
        "llm-http"
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Features with only the baseline estimate populated.
    fn baseline_only(baseline_estimate: u64) -> CorrectionFeatures {
        CorrectionFeatures {
            baseline_estimate,
            ..Default::default()
        }
    }

    /// The identity corrector echoes the baseline and reports its name.
    #[test]
    fn identity_returns_baseline() {
        let identity = IdentityCorrector;
        let outcome = identity.correct(&baseline_only(1234)).unwrap();
        assert_eq!(outcome, Some(1234));
        assert_eq!(identity.name(), "identity");
    }

    /// The stub yields `Ok(None)` for populated and default features alike.
    #[test]
    fn tabpfn_stub_always_returns_none() {
        let stub = TabPfnStub;
        let populated = baseline_only(9999);
        assert_eq!(
            stub.correct(&populated).unwrap(),
            None,
            "TabPfnStub must always return Ok(None) — it documents the integration point"
        );
        assert_eq!(stub.name(), "tabpfn-stub");
        let defaulted = CorrectionFeatures::default();
        assert_eq!(stub.correct(&defaulted).unwrap(), None);
    }

    /// `to_vec` keeps declaration order and maps `None` to `0.0`.
    #[test]
    fn feature_vec_layout_is_stable() {
        let encoded = CorrectionFeatures {
            baseline_estimate: 100,
            left_input_rows: Some(10),
            right_input_rows: None,
            left_distinct: Some(7),
            right_distinct: None,
            predicate_count: 3,
            join_depth: 2,
        }
        .to_vec();
        assert_eq!(encoded.len(), CorrectionFeatures::FEATURE_LEN);
        let expected = [100.0, 10.0, 0.0, 7.0, 0.0, 3.0, 2.0];
        for (slot, want) in expected.iter().enumerate() {
            assert_eq!(encoded[slot], *want);
        }
    }
}
#[cfg(all(test, feature = "gbt"))]
mod gbt_tests {
    use super::gbt::{GbtCorrector, GbtOptions};
    use super::{CorrectionFeatures, Corrector};
    use crate::feedback::Observation;

    /// `n` observations whose actual row count is exactly twice the estimate.
    fn synthetic_double(n: u64) -> Vec<Observation> {
        let mut observations = Vec::new();
        for step in 1..=n {
            let est_rows = step * 10;
            observations.push(Observation {
                template_hash: "syn".into(),
                plan_fingerprint: "p".into(),
                est_rows,
                actual_rows: step * 10 * 2,
                latency_ms: None,
            });
        }
        observations
    }

    /// Training options shared by the prediction tests, varying only the ceiling.
    fn options_with_ceiling(ceiling: u64) -> GbtOptions {
        GbtOptions {
            learning_rate: 0.3,
            max_depth: 4,
            num_trees: 50,
            ceiling,
            min_leaf_size: 1,
        }
    }

    /// Trains on the doubling data and corrects a baseline of 500.
    fn correct_baseline_500(corrector: &GbtCorrector) -> u64 {
        let probe = CorrectionFeatures {
            baseline_estimate: 500,
            ..Default::default()
        };
        corrector.correct(&probe).expect("correct").expect("Some")
    }

    /// Single observation with the given row counts.
    fn observation(est_rows: u64, actual_rows: u64) -> Observation {
        Observation {
            template_hash: "z".into(),
            plan_fingerprint: "p".into(),
            est_rows,
            actual_rows,
            latency_ms: None,
        }
    }

    #[test]
    fn predicts_roughly_double_when_training_says_double() {
        let corrector = GbtCorrector::train(&synthetic_double(200), options_with_ceiling(u64::MAX))
            .expect("training");
        let estimate = correct_baseline_500(&corrector);
        let relative = estimate as f64 / 1000.0;
        assert!(
            (0.75..=1.25).contains(&relative),
            "expected ~1000, got {} (ratio {})",
            estimate,
            relative
        );
        assert_eq!(corrector.name(), "gbt");
    }

    #[test]
    fn ceiling_clamps_when_prediction_exceeds_it() {
        let corrector = GbtCorrector::train(&synthetic_double(200), options_with_ceiling(100))
            .expect("training");
        let estimate = correct_baseline_500(&corrector);
        assert_eq!(estimate, 100, "ceiling must clamp the corrected estimate");
        assert_eq!(corrector.ceiling(), 100);
    }

    #[test]
    fn empty_observations_errors() {
        let Err(err) = GbtCorrector::train(&[], GbtOptions::default()) else {
            panic!("expected error on empty observations");
        };
        assert!(matches!(err, crate::Error::Feedback(_)));
    }

    #[test]
    fn all_zero_observations_errors() {
        // One observation with a zero estimate, one with zero actual rows:
        // both are unusable for a log-ratio target.
        let unusable = vec![observation(0, 5), observation(5, 0)];
        let Err(err) = GbtCorrector::train(&unusable, GbtOptions::default()) else {
            panic!("expected error when all observations are zero");
        };
        assert!(matches!(err, crate::Error::Feedback(_)));
    }
}
#[cfg(all(test, feature = "additive_gbt"))]
mod additive_tests {
    use super::additive::{AdditiveGbtCorrector, AdditiveGbtOptions};
    use super::{CorrectionFeatures, Corrector};
    use crate::feedback::Observation;

    /// `n` observations with varying estimates but a constant actual row count.
    fn synthetic_constant(n: u64, target: u64) -> Vec<Observation> {
        let mut observations = Vec::new();
        for step in 1..=n {
            observations.push(Observation {
                template_hash: "syn-add".into(),
                plan_fingerprint: "p".into(),
                est_rows: step * 10,
                actual_rows: target,
                latency_ms: None,
            });
        }
        observations
    }

    /// Training options shared by the prediction tests, varying only the ceiling.
    fn options_with_ceiling(ceiling: u64) -> AdditiveGbtOptions {
        AdditiveGbtOptions {
            learning_rate: 0.3,
            max_depth: 4,
            num_trees: 50,
            ceiling,
            min_leaf_size: 1,
        }
    }

    /// Runs `corrector` on features whose only populated field is the baseline.
    fn correct_baseline(corrector: &AdditiveGbtCorrector, baseline_estimate: u64) -> u64 {
        let probe = CorrectionFeatures {
            baseline_estimate,
            ..Default::default()
        };
        corrector.correct(&probe).expect("correct").expect("Some")
    }

    #[test]
    fn predicts_near_constant_when_training_is_constant() {
        let training = synthetic_constant(200, 1000);
        let corrector = AdditiveGbtCorrector::train(&training, options_with_ceiling(u64::MAX))
            .expect("training additive corrector");
        let corrected = correct_baseline(&corrector, 500);
        assert!(
            (800..=1200).contains(&corrected),
            "expected ~1000, got {corrected}"
        );
        assert_eq!(corrector.name(), "additive_gbt");
    }

    #[test]
    fn ceiling_clamps_when_prediction_exceeds_it() {
        let training = synthetic_constant(200, 1000);
        let corrector =
            AdditiveGbtCorrector::train(&training, options_with_ceiling(100)).expect("training");
        let corrected = correct_baseline(&corrector, 500);
        assert_eq!(corrected, 100, "ceiling must clamp the additive correction");
        assert_eq!(corrector.ceiling(), 100);
    }

    #[test]
    fn corrects_nonzero_even_when_baseline_estimate_is_zero() {
        let training = synthetic_constant(200, 1000);
        let corrector = AdditiveGbtCorrector::train(&training, AdditiveGbtOptions::default())
            .expect("training");
        let corrected = correct_baseline(&corrector, 0);
        assert!(
            corrected > 0,
            "additive corrector must return non-zero even when baseline_estimate = 0; got {corrected}"
        );
    }

    #[test]
    fn empty_observations_errors() {
        let Err(err) = AdditiveGbtCorrector::train(&[], AdditiveGbtOptions::default()) else {
            panic!("expected error on empty observations");
        };
        assert!(matches!(err, crate::Error::Feedback(_)));
    }
}
#[cfg(all(test, feature = "tabpfn_http"))]
mod tabpfn_http_tests {
    use super::tabpfn::{TabPfnHttpCorrector, TabPfnHttpOptions};
    use super::{CorrectionFeatures, Corrector};

    /// A request to a loopback port that is expected to refuse the connection
    /// must degrade to `Ok(None)`.
    #[test]
    fn http_failure_returns_none_not_error() {
        let unreachable = TabPfnHttpOptions {
            base_url: "http://127.0.0.1:1/infer".into(),
            timeout_ms: 50,
            ceiling: u64::MAX,
        };
        let corrector = TabPfnHttpCorrector::new(unreachable);
        let probe = CorrectionFeatures {
            baseline_estimate: 1234,
            ..Default::default()
        };

        let result = corrector.correct(&probe);
        assert!(
            result.is_ok(),
            "tabpfn-http transport failure must not propagate as Err; got {result:?}"
        );
        assert_eq!(
            result.unwrap(),
            None,
            "tabpfn-http transport failure must yield Ok(None) so the engine falls back cleanly"
        );
        assert_eq!(corrector.name(), "tabpfn-http");
    }

    /// A string that is not a URL is handled like any other request failure.
    #[test]
    fn malformed_url_returns_none() {
        let corrector = TabPfnHttpCorrector::with_url("not a url at all");
        let outcome = corrector
            .correct(&CorrectionFeatures::default())
            .expect("never Err");
        assert_eq!(outcome, None);
    }

    /// Default options use plaintext HTTP, a positive timeout and no ceiling.
    #[test]
    fn options_default_is_localhost() {
        let TabPfnHttpOptions {
            base_url,
            timeout_ms,
            ceiling,
        } = TabPfnHttpOptions::default();
        assert!(base_url.starts_with("http://"));
        assert!(timeout_ms > 0);
        assert_eq!(ceiling, u64::MAX);
    }
}
#[cfg(all(test, feature = "llm_http"))]
mod llm_http_tests {
    use super::llm::{
        DEFAULT_TIMEOUT_MS, DEFAULT_URL, LlmHttpCorrector, LlmHttpOptions, MAX_TIMEOUT_MS,
    };
    use super::{CorrectionFeatures, Corrector};
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};
    use std::thread;
    use std::time::Duration;

    /// Starts a one-thread mock HTTP server on an ephemeral loopback port.
    ///
    /// The server accepts at most `max_requests` connections. For each one it
    /// performs a single read, calls `responder` with the zero-based request
    /// index, and replies `200 OK` with the returned bytes as the body.
    /// Returns the endpoint URL and a counter of connections served.
    fn spawn_mock(
        responder: impl Fn(usize) -> Vec<u8> + Send + Sync + 'static,
        max_requests: usize,
    ) -> (String, Arc<AtomicUsize>) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind loopback");
        let port = listener.local_addr().unwrap().port();
        let served = Arc::new(AtomicUsize::new(0));
        let served_in_worker = Arc::clone(&served);
        let responder = Arc::new(Mutex::new(responder));

        thread::spawn(move || {
            listener
                .set_nonblocking(false)
                .expect("blocking mode for mock");
            let io_budget = Some(Duration::from_secs(2));
            // Failed accepts still count against `max_requests` but are skipped.
            for mut conn in listener.incoming().take(max_requests).flatten() {
                let _ = conn.set_read_timeout(io_budget);
                let _ = conn.set_write_timeout(io_budget);
                // One read is enough: the request content is never inspected.
                let mut scratch = [0u8; 4096];
                let _ = conn.read(&mut scratch);
                let index = served_in_worker.fetch_add(1, Ordering::SeqCst);
                let body = responder.lock().unwrap()(index);
                let header = format!(
                    "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
                    body.len()
                );
                let _ = conn.write_all(header.as_bytes());
                let _ = conn.write_all(&body);
                let _ = conn.flush();
            }
        });

        (format!("http://127.0.0.1:{port}/infer"), served)
    }

    /// Options pointing at a loopback port that is expected to refuse connections.
    fn unreachable_options(timeout_ms: u64) -> LlmHttpOptions {
        LlmHttpOptions {
            base_url: "http://127.0.0.1:1/infer".into(),
            timeout_ms,
            ceiling: u64::MAX,
        }
    }

    #[test]
    fn http_failure_returns_none_not_error() {
        let corrector = LlmHttpCorrector::new(unreachable_options(50));
        let probe = CorrectionFeatures {
            baseline_estimate: 1234,
            ..Default::default()
        };

        let result = corrector.correct(&probe);
        assert!(
            result.is_ok(),
            "llm-http transport failure must not propagate as Err; got {result:?}"
        );
        assert_eq!(
            result.unwrap(),
            None,
            "llm-http transport failure must yield Ok(None) so the engine falls back cleanly"
        );
        assert_eq!(corrector.name(), "llm-http");
    }

    #[test]
    fn malformed_url_returns_none() {
        let corrector = LlmHttpCorrector::with_url("not a url at all");
        let outcome = corrector
            .correct(&CorrectionFeatures::default())
            .expect("never Err");
        assert_eq!(outcome, None);
    }

    #[test]
    fn options_default_is_localhost_on_llm_port() {
        let LlmHttpOptions {
            base_url,
            timeout_ms,
            ceiling,
        } = LlmHttpOptions::default();
        assert_eq!(base_url, DEFAULT_URL);
        assert!(base_url.contains(":8766"));
        assert_eq!(timeout_ms, DEFAULT_TIMEOUT_MS);
        assert_eq!(ceiling, u64::MAX);
    }

    #[test]
    fn timeout_is_saturated_to_max() {
        let corrector = LlmHttpCorrector::new(unreachable_options(MAX_TIMEOUT_MS * 10));
        assert_eq!(corrector.options().timeout_ms, MAX_TIMEOUT_MS);
    }

    #[test]
    fn mock_success_returns_clamped_estimate() {
        let (url, served) = spawn_mock(|_| br#"{"estimate": 4242}"#.to_vec(), 2);
        let corrector = LlmHttpCorrector::new(LlmHttpOptions {
            base_url: url,
            timeout_ms: 2_000,
            ceiling: 1_000_000,
        });
        let probe = CorrectionFeatures {
            baseline_estimate: 1_000,
            ..Default::default()
        };
        let outcome = corrector.correct(&probe).expect("ok");
        assert_eq!(outcome, Some(4242));
        assert!(served.load(Ordering::SeqCst) >= 1);
    }

    #[test]
    fn mock_clamps_to_ceiling() {
        let (url, _served) = spawn_mock(|_| br#"{"estimate": 99999999}"#.to_vec(), 2);
        let corrector = LlmHttpCorrector::new(LlmHttpOptions {
            base_url: url,
            timeout_ms: 2_000,
            ceiling: 500,
        });
        let outcome = corrector
            .correct(&CorrectionFeatures::default())
            .expect("ok");
        assert_eq!(outcome, Some(500));
    }

    #[test]
    fn mock_parse_error_returns_none() {
        let (url, _served) = spawn_mock(|_| b"not json at all".to_vec(), 2);
        let outcome = LlmHttpCorrector::with_url(url)
            .correct(&CorrectionFeatures::default())
            .expect("ok");
        assert_eq!(outcome, None);
    }
}